Define Custom Metric Object
Note
This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.
In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.
 How To Decide Which Metric Type To Use
If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.
To define a custom deep learning metric class, you can use the template in this example, which takes you through these steps:
- Name the metric — Give the metric a name so that you can use it in MATLAB®. 
- Declare the metric properties — Specify the public and private properties of the metric. 
- Create a constructor function — Specify how to construct the metric and set default values. 
- Create an initialization function (optional) — Specify how to initialize variables and run validation checks. 
- Create a reset function — Specify how to reset the metric properties between iterations. 
- Create an update function — Specify how to update metric properties between iterations. 
- Create an aggregation function — Specify how to aggregate the metric values across multiple instances of the metric object. 
- Create an evaluation function — Specify how to calculate the metric value for each iteration. 
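This example shows how to create a custom false positive rate (FPR) metric. This equation defines the metric:

$$\mathrm{FPR} = \frac{FP}{FP + TN}$$

where FP is the number of false positives and TN is the number of true negatives.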
To see the completed metric class definition, see Completed Metric.
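As a quick check of the equation, this standalone sketch computes the FPR from hypothetical counts for a single class: 3 false positives and 27 true negatives give 3/(3 + 27) = 0.1. The counts are illustrative only and do not come from a trained network.

```matlab
% Hypothetical counts for one class (illustrative values only).
falsePositives = 3;
trueNegatives = 27;

% FPR is the fraction of actual negatives that the model labels as positive.
fpr = falsePositives / (falsePositives + trueNegatives);
fprintf("FPR = %.2f\n", fpr);
```

This prints FPR = 0.10.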
Metric Template
Copy the metric template into a new file in MATLAB. This template gives the structure of a metric class definition. It outlines:
- The properties block for public metric properties. This block must contain the Name property.
- The properties block for private metric properties. This block is optional.
- The metric constructor function. 
- The optional initialize function.
- The required reset, update, aggregate, and evaluate functions.
classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %     metric - Metric to initialize
            %     batchY - Mini-batch of predictions
            %     batchT - Mini-batch of targets
            %
            % Output:
            %     metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %     metric - Metric containing properties to reset
            %
            % Output:
            %     metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %     metric - Metric containing properties to update
            %     batchY - Mini-batch of predictions
            %     batchT - Mini-batch of targets
            %
            % Output:
            %     metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %     metric  - Metric containing properties to aggregate
            %     metric2 - Metric containing properties to aggregate
            %
            % Output:
            %     metric - Metric with aggregated properties
            %

            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %     metric - Metric containing properties to use to
            %              evaluate the metric value
            %
            % Output:
            %     val - Evaluated metric value
            %
            % To return multiple metric values, replace val with val1,...
            % valN.

            % Define metric evaluation function here.
        end
    end
end
Metric Name
First, give the metric a name. In the first line of the class file, replace the existing name myMetric with fprMetric.
classdef fprMetric < deep.Metric
    ...
end
Next, rename the myMetric constructor function (the first function in the methods section) so that it has the same name as the metric.
methods
    function metric = fprMetric(args)
        ...
    end
    ...
end
Save Metric
Save the metric class file in a new file with the name fprMetric and the .m extension. The file name must match the metric name. To use the metric, you must save the file in the current folder or in a folder on the MATLAB path.
Declare Properties
Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.
- properties: Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.
- properties (Access = private): Only members of the defining class can access the property.
Declare Public Properties
Declare public properties by listing them in the properties section. This section must contain the Name property.
properties
    % (Required) Metric name.
    Name
end
Declare Private Properties
Declare private properties by listing them in the properties (Access = private) section. This metric requires two properties to evaluate the value: true negatives (TNs) and false positives (FPs). Only the functions within the metric class require access to these values.
properties (Access = private)
    % Define true negatives (TNs) and false positives (FPs).
    TrueNegatives
    FalsePositives
end
Create Constructor Function
Create the function that constructs the metric and initializes the metric properties. If the software requires any variables to evaluate the metric value, then these variables must be inputs to the constructor function.
The FPR metric constructor function accepts the Name, NetworkOutput, and Maximize arguments, all of which are optional when you use the constructor to create a metric object. Specify an args input to the fprMetric function that corresponds to the optional name-value arguments. Add a comment to explain the syntax of the function.
function metric = fprMetric(args)
    % metric = fprMetric creates an fprMetric metric object.
    % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
    % also specifies the optional Name, NetworkOutput, and Maximize
    % options. By default, the metric name is "FPR". By default,
    % the NetworkOutput is [], which corresponds to using all of
    % the network outputs. Maximize is set to 0 as the optimal value
    % occurs when the FPR is minimized.

    ...
end
Next, set the default values for the metric properties. Parse the input arguments using an arguments block. Specify the default metric name as "FPR", the default network output as [], and the Maximize property as 0. The metric name appears in plots and verbose output.
function metric = fprMetric(args)
    ...
    arguments
        args.Name = "FPR"
        args.NetworkOutput = []
        args.Maximize = 0
    end
    ...
end
Set the properties of the metric.
function metric = fprMetric(args)
    ...
    % Set the metric name.
    metric.Name = args.Name;

    % To support this metric for use with multi-output networks, set
    % the network output.
    metric.NetworkOutput = args.NetworkOutput;

    % To support this metric for early stopping and returning the
    % best network, set the maximize property.
    metric.Maximize = args.Maximize;
end
View the completed constructor function. With this constructor function, the command fprMetric(Name="fpr") creates an FPR metric object with the name "fpr".
function metric = fprMetric(args)
    % metric = fprMetric creates an fprMetric metric object.
    % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
    % also specifies the optional Name, NetworkOutput, and Maximize
    % options. By default, the metric name is "FPR". By default,
    % the NetworkOutput is [], which corresponds to using all of
    % the network outputs. Maximize is set to 0 as the optimal value
    % occurs when the FPR is minimized.

    arguments
        args.Name = "FPR"
        args.NetworkOutput = []
        args.Maximize = 0
    end

    % Set the metric name.
    metric.Name = args.Name;

    % To support this metric for use with multi-output networks, set
    % the network output.
    metric.NetworkOutput = args.NetworkOutput;

    % To support this metric for early stopping and returning the
    % best network, set the maximize property.
    metric.Maximize = args.Maximize;
end
Create Initialization Function
Create the optional function that initializes variables and runs validation checks. For this example, the metric does not need the initialize function, so you can delete it. For an example of an initialize function, see Initialization Function.
Create Reset Function
Create the function that resets the metric properties. The software calls this function before each iteration. For the FPR metric, reset the TN and FP values to zero at the start of each iteration.
function metric = reset(metric)
    % metric = reset(metric) resets the metric properties.
    metric.TrueNegatives = 0;
    metric.FalsePositives = 0;
end
Create Update Function
Create the function that updates the metric properties that you use to compute the FPR value. The software calls this function for each training and validation mini-batch.
In the update function, define these steps: 
- Find the maximum score for each observation. The maximum score corresponds to the predicted class for each observation. 
- Find the TN and FP values. 
- Add the batch TN and FP values to the running total number of TNs and FPs. 
function metric = update(metric,batchY,batchT)
    % metric = update(metric,batchY,batchT) updates the metric
    % properties.

    % Find the channel (class) dimension.
    cDim = finddim(batchY,"C");

    % Find the maximum score, which corresponds to the predicted
    % class. Set the predicted class to 1 and all other classes to 0.
    batchY = batchY == max(batchY,[],cDim);

    % Find the TN and FP values for this batch.
    batchTrueNegatives = sum(~batchY & ~batchT, 2);
    batchFalsePositives = sum(batchY & ~batchT, 2);

    % Add the batch values to the running totals and update the metric
    % properties.
    metric.TrueNegatives = metric.TrueNegatives + batchTrueNegatives;
    metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
end
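To check the counting logic of update without Deep Learning Toolbox, this standalone sketch repeats it on plain numeric arrays. It assumes a class-by-observation ("CB") layout, so the class dimension is 1 and the observation dimension is 2; because plain arrays carry no format labels, it uses dimension 1 directly instead of calling finddim. The scores, targets, and mini-batch size are hypothetical.

```matlab
% Hypothetical scores for 3 classes (rows) and 4 observations (columns).
scores = [0.7 0.1 0.2 0.3
          0.2 0.8 0.1 0.3
          0.1 0.1 0.7 0.4];

% Hypothetical one-hot targets: the observations belong to classes 1, 2, 1, 3.
targets = [1 0 1 0
           0 1 0 0
           0 0 0 1];

% Equivalent of reset: start the running totals at zero.
trueNegatives = 0;
falsePositives = 0;

% Process the data as mini-batches of two observations each.
batchSize = 2;
for firstCol = 1:batchSize:size(scores,2)
    cols = firstCol:firstCol+batchSize-1;
    batchY = scores(:,cols);
    batchT = targets(:,cols);

    % Mark the highest-scoring class in each column as the prediction.
    batchY = batchY == max(batchY,[],1);

    % Add the per-class counts for this mini-batch to the running totals.
    trueNegatives = trueNegatives + sum(~batchY & ~batchT, 2);
    falsePositives = falsePositives + sum(batchY & ~batchT, 2);
end

disp([trueNegatives falsePositives])
```

The totals are TN = [2; 3; 2] and FP = [0; 0; 1]: only observation 3 (true class 1) is misclassified, as class 3. Because the prediction step compares scores with their column maximum using ==, a tie for the highest score marks every tied class as predicted for that observation.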
For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.
- When using the metric with trainnet and the targets are categorical arrays, if the loss function is "index-crossentropy", then the software automatically converts the targets to numeric class indices and passes them to the metric. For other loss functions, the software converts the targets to one-hot encoded vectors and passes them to the metric.
- When using the metric with testnet and the targets are categorical arrays, if the specified metrics include "index-crossentropy" but do not include "crossentropy", then the software converts the targets to numeric class indices and passes them to the metric. Otherwise, the software converts the targets to one-hot encoded vectors and passes them to the metric.
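The update function in this example compares batchY and batchT elementwise, so it assumes one-hot targets of the same size as the predictions. If a metric might receive numeric class indices instead, as described above for "index-crossentropy", one option is to expand the indices to one-hot form before counting. This sketch assumes the indices arrive as a 1-by-B row vector of 1-based class numbers; classIndices and numClasses are hypothetical names, not properties of the metric.

```matlab
% Hypothetical numeric class indices for 4 observations (1-based).
classIndices = [1 2 1 3];
numClasses = 3;

% Compare a column of class numbers with the row of indices. Implicit
% expansion produces a numClasses-by-B logical one-hot matrix.
oneHot = (1:numClasses)' == classIndices;

disp(oneHot)
```

Each column of oneHot contains a single true value in the row of that observation's class, which matches the class-by-observation layout that the update function expects.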
Create Aggregation Function
Create the function that specifies how to combine the metric values and properties across multiple instances of the metric. For example, during parallel training, the aggregate function defines how to combine the properties of multiple instances of the same metric object.
For this example, to combine the TN and FP values, add the values from each metric instance.
function metric = aggregate(metric,metric2)
    % metric = aggregate(metric,metric2) aggregates the metric
    % properties across two instances of the metric.
    metric.TrueNegatives = metric.TrueNegatives + metric2.TrueNegatives;
    metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
end
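Adding the raw counts, rather than averaging the FPR of each instance, gives the same result that a single instance would produce after seeing all of the data. This plain-numeric sketch compares the two approaches using hypothetical per-class counts from two metric instances, for example two parallel workers. The macroFPR helper is a hypothetical anonymous function that mirrors the evaluate step.

```matlab
% Hypothetical per-class counts (3 classes) from two metric instances.
tn1 = [2; 3; 2];  fp1 = [0; 0; 1];
tn2 = [8; 5; 6];  fp2 = [2; 1; 0];

% Macro-averaged FPR from per-class counts (mirrors the evaluate step).
macroFPR = @(tn,fp) mean(fp ./ (fp + tn + eps));

% Aggregate as in this example: add the counts, then evaluate once.
pooled = macroFPR(tn1 + tn2, fp1 + fp2);

% Alternative for comparison: evaluate each instance, then average.
averaged = mean([macroFPR(tn1,fp1) macroFPR(tn2,fp2)]);

fprintf("Pooled counts: %.4f, averaged rates: %.4f\n", pooled, averaged);
```

This prints Pooled counts: 0.1296, averaged rates: 0.1167. The two values differ because averaging rates gives each instance equal weight regardless of how many negatives it saw, whereas pooling the counts weights every observation equally.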
Create Evaluation Function
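Create the function that specifies how to compute the metric value in each iteration. This equation defines the FPR metric:

$$\mathrm{FPR} = \frac{FP}{FP + TN}$$

Implement this equation in the evaluate function. Compute the FPR for each class, and then find the macro average by taking the mean across all K classes:

$$\mathrm{FPR}_{\text{macro}} = \frac{1}{K}\sum_{k=1}^{K} \frac{FP_k}{FP_k + TN_k}$$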
function val = evaluate(metric)
    % val = evaluate(metric) uses the properties in metric to return the
    % evaluated metric value.

    % Extract TN and FP values.
    tn = metric.TrueNegatives;
    fp = metric.FalsePositives;

    % Compute the FPR value for each class (elementwise division), and
    % then take the macro average across classes.
    val = mean(fp./(fp+tn+eps));
end
Because the denominator is zero for a class that has no negative observations (that is, when every observation belongs to that class), add eps to the denominator to prevent the metric from returning NaN.
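This standalone sketch shows the effect of eps on hypothetical accumulated counts in which class 3 has no negatives, so FP + TN = 0 for that class. It also uses elementwise division (./) so that each class gets its own rate before averaging.

```matlab
% Hypothetical accumulated counts for 3 classes. Every observation seen
% so far belongs to class 3, so class 3 has no negatives: FP + TN = 0.
tn = [2; 1; 0];
fp = [0; 1; 0];

% Without eps, 0/0 yields NaN and the macro average becomes NaN.
withoutEps = mean(fp ./ (fp + tn));

% With eps in the denominator, the class without negatives contributes 0.
withEps = mean(fp ./ (fp + tn + eps));

fprintf("Without eps: %g, with eps: %.4f\n", withoutEps, withEps);
```

Here withoutEps is NaN and withEps is approximately 0.1667. One trade-off of this choice is that a class with no negatives reports an FPR of 0 even though its rate is, strictly speaking, undefined.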
Completed Metric
View the completed metric class file.
Note
For more information about when the software calls each function in the class, see Function Call Order.
classdef fprMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name
    end

    properties (Access = private)
        % Define true negatives (TNs) and false positives (FPs).
        TrueNegatives
        FalsePositives
    end

    methods
        function metric = fprMetric(args)
            % metric = fprMetric creates an fprMetric metric object.
            % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
            % also specifies the optional Name, NetworkOutput, and Maximize
            % options. By default, the metric name is "FPR". By default,
            % the NetworkOutput is [], which corresponds to using all of
            % the network outputs. Maximize is set to 0 as the optimal value
            % occurs when the FPR value is minimized.

            arguments
                args.Name = "FPR"
                args.NetworkOutput = []
                args.Maximize = false
            end

            % Set the metric name value.
            metric.Name = args.Name;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the maximize property.
            metric.Maximize = args.Maximize;
        end

        function metric = reset(metric)
            % metric = reset(metric) resets the metric properties.
            metric.TrueNegatives = 0;
            metric.FalsePositives = 0;
        end

        function metric = update(metric,batchY,batchT)
            % metric = update(metric,batchY,batchT) updates the metric
            % properties.

            % Find the channel (class) dimension.
            cDim = finddim(batchY,"C");

            % Find the maximum score, which corresponds to the predicted
            % class. Set the predicted class to 1 and all other classes to 0.
            batchY = batchY == max(batchY,[],cDim);

            % Find the TN and FP values for this batch.
            batchTrueNegatives = sum(~batchY & ~batchT, 2);
            batchFalsePositives = sum(batchY & ~batchT, 2);

            % Add the batch values to the running totals and update the metric
            % properties.
            metric.TrueNegatives = metric.TrueNegatives + batchTrueNegatives;
            metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
        end

        function metric = aggregate(metric,metric2)
            % metric = aggregate(metric,metric2) aggregates the metric
            % properties across two instances of the metric.
            metric.TrueNegatives = metric.TrueNegatives + metric2.TrueNegatives;
            metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
        end

        function val = evaluate(metric)
            % val = evaluate(metric) uses the properties in metric to return the
            % evaluated metric value.

            % Extract TN and FP values.
            tn = metric.TrueNegatives;
            fp = metric.FalsePositives;

            % Compute the FPR value.
            val = mean(fp./(fp+tn+eps));
        end
    end
end
Use Custom Metric During Training
You can use a custom metric in the same way as any other metric in Deep Learning Toolbox™. This section shows how to create and train a network for digit classification and track the FPR value.
Unzip the digit sample data and create an image datastore. The imageDatastore function automatically labels the images based on folder names.
unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
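Split the data into training and validation sets, using 750 images from each class for training and the remaining images for validation. Then define the network layers.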
numTrainingFiles = 750;
[imdsTrain,imdsVal] = splitEachLabel(imds,numTrainingFiles,"randomize");

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    fullyConnectedLayer(10)
    softmaxLayer];
Create an fprMetric object.
metric = fprMetric(Name="FalsePositiveRate")

metric = 
  fprMetric with properties:

             Name: "FalsePositiveRate"
    NetworkOutput: []
         Maximize: 0
Specify the FPR metric in the training options. To plot the metric during training, set Plots to "training-progress". To display the values during training, set Verbose to true. To return the network that achieves the best validation FPR value, set ObjectiveMetricName to the metric name and OutputNetwork to "best-validation".
options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    Metrics=metric, ...
    ValidationData=imdsVal, ...
    ValidationFrequency=50, ...
    Verbose=true, ...
    Plots="training-progress", ...
    ObjectiveMetricName="FalsePositiveRate", ...
    OutputNetwork="best-validation");
Train the network using the trainnet function. The values for the training and validation sets appear in the plot. 
net = trainnet(imdsTrain,layers,"crossentropy",options);

    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    ValidationLoss    TrainingFalsePositiveRate    ValidationFalsePositiveRate
    _________    _____    ___________    _________    ____________    ______________    _________________________    ___________________________
            0        0       00:00:04        0.001                            13.488                                                     0.10018
            1        1       00:00:04        0.001          13.974                                        0.10322                               
           50        1       00:00:20        0.001          2.7424            2.7448                     0.037368                       0.038889
          100        2       00:00:28        0.001          1.2965            1.2235                     0.027008                       0.023333
          150        3       00:00:35        0.001         0.64661           0.80412                     0.013953                       0.017867
          200        4       00:00:42        0.001         0.18627           0.53273                     0.006153                       0.012311
          250        5       00:00:50        0.001         0.16763           0.49371                    0.0060146                       0.012267
          290        5       00:00:57        0.001         0.25976           0.39347                    0.0062093                      0.0098222
Training stopped: Max epochs completed

See Also
trainingOptions | trainnet | dlnetwork
