Define Custom Metric Function

Note

This topic explains how to define a custom metric function for your task. Use a custom metric function if Deep Learning Toolbox™ does not support the metric you need. For a list of built-in metrics in Deep Learning Toolbox, see Metrics. If a built-in MATLAB® function satisfies the required syntax, then you can use that function instead. For example, you can use the built-in l1loss function to find the L1 loss. For information about the required function syntax, see Create Custom Metric Function.

In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.

If Deep Learning Toolbox does not provide the metric that you need for your task, then in many cases you can create a custom metric using a function. After you define the metric function, you can specify the metric as the Metrics name-value argument in the trainingOptions function.

Create Custom Metric Function

To create a custom metric function, you can use this template.

function val = myMetricFunction(Y,T)
% Evaluate custom metric.

% Inputs:
%          Y - Formatted dlarray of predictions
%          T - Formatted dlarray of targets
%
% Outputs:
%           val - Metric value
%
% Define the metric function here.
end

The function takes as input a formatted dlarray object of network predictions Y and a formatted dlarray object of network targets T. The function must return a single numeric output corresponding to the metric value.

Depending on your metric, you sometimes need to know the dimension labels before computing the metric. Use the finddim function to find dimensions with a specific label. For example, to average your metric across batches, you need to know the batch dimension.
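As a minimal sketch of the lookup described above, the following snippet builds a hypothetical formatted dlarray (the 3-by-4 size and the "CB" format are assumptions made for illustration) and uses finddim to locate the batch and channel dimensions before averaging over the batch.

```matlab
% Hypothetical predictions: 3 channels (classes) by 4 observations.
% The size and the "CB" format are assumptions for illustration only.
Y = dlarray(rand(3,4),"CB");

% Look up dimensions by label instead of hard-coding their positions.
bDim = finddim(Y,"B");   % batch dimension, 2 for the "CB" format
cDim = finddim(Y,"C");   % channel dimension, 1 for the "CB" format

% Average over the batch dimension, leaving one value per channel.
perChannelMean = mean(Y,bDim);
disp(size(perChannelMean))
```

Looking up the dimension by label keeps the metric independent of the order in which the network happens to arrange its output dimensions.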

Note

When you have validation data in mini-batches, the software computes the validation metric for each mini-batch and then returns the average of those values. For some metrics, this behavior can result in a different metric value than if you compute the metric using the whole validation set at once. In most cases, the values are similar. To use a custom metric that is not batch-averaged for the validation data, you must create a custom metric object. For more information, see Define Custom Deep Learning Metric Object.
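The effect of batch averaging can be illustrated with a small sketch in plain MATLAB. The error values and the two-batch split below are made up for illustration, and root mean squared error is used only because it is a simple metric for which the two computations differ.

```matlab
% Hypothetical validation errors (prediction minus target),
% split into two mini-batches. The values are illustrative only.
errBatch1 = [1 1];
errBatch2 = [3 3];

% Root mean squared error of a vector of errors.
rmse = @(e) sqrt(mean(e.^2));

% Metric computed per mini-batch and then averaged.
batchAveraged = mean([rmse(errBatch1) rmse(errBatch2)]);   % (1 + 3)/2 = 2

% Metric computed once over the whole validation set.
wholeSet = rmse([errBatch1 errBatch2]);                    % sqrt(5), about 2.2361

fprintf("batch-averaged: %.4f, whole set: %.4f\n",batchAveraged,wholeSet)
```

The two values differ because the square root is applied before averaging in one case and after in the other. A metric that is a plain mean over observations gives the same value both ways when the mini-batches have equal size.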

To use the metric during training, specify the function handle as the Metrics option of the trainingOptions function.

trainingOptions("sgdm", ...
    Metrics=@myMetricFunction)

Example Regression Metric

For regression tasks, the function must accept two formatted dlarray objects: the predictions and the targets.

This code shows an example of a regression metric. This custom metric function computes the symmetric mean absolute percentage error (SMAPE) value given predictions and targets. This equation defines the SMAPE value:

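$$\mathrm{SMAPE} = \frac{100}{n}\sum_{i=1}^{n}\frac{|Y_i - T_i|}{(|T_i| + |Y_i|)/2},$$

where $Y_i$ are the network predictions, $T_i$ are the target responses, and $n$ is the number of values being averaged.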

function val = SMAPE(Y,T)
% Compute SMAPE value.

absoluteDifference = abs(Y-T);
absoluteAvg = (abs(Y) + abs(T))./2;
proportion = absoluteDifference./absoluteAvg;

val = 100*mean(proportion,"all");
end
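To check the arithmetic, this sketch repeats the steps of the SMAPE function inline on two made-up observations. The input values and the "CB" format are assumptions chosen so that the result is easy to verify by hand.

```matlab
% Hypothetical predictions and targets: 1 channel, 2 observations.
Y = dlarray([2 2],"CB");
T = dlarray([2 6],"CB");

% Same steps as the body of the SMAPE function.
absoluteDifference = abs(Y-T);                  % [0 4]
absoluteAvg = (abs(Y) + abs(T))./2;             % [2 4]
proportion = absoluteDifference./absoluteAvg;   % [0 1]

val = 100*mean(proportion,"all");               % 100*(0 + 1)/2 = 50
disp(val)
```

The first observation is predicted exactly and contributes 0, while the second contributes 4/4 = 1, so the mean proportion is 0.5.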

Example Classification Metric

For classification tasks, the function must accept formatted dlarray objects of predictions and targets, with the targets encoded as one-hot vectors. In the example below, each column represents a class and each row represents an observation, so the array encodes four observations across three classes. For more information, see the onehotencode function.

Y =
     0     0     1
     1     0     0
     0     0     1
     0     1     0 
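One way to produce an array like this is with onehotencode. In this sketch the class names "a", "b", and "c" are assumptions, chosen so that the output has the same pattern as the array above.

```matlab
% Hypothetical labels for four observations. The class names are
% assumptions; categories are ordered alphabetically as a, b, c.
labels = categorical(["c";"a";"c";"b"]);

% Encode along the second dimension so that each column is a class
% and each row is an observation.
Y = onehotencode(labels,2);
disp(Y)
```

The second argument selects the dimension that holds the classes. Passing 2 gives one row per observation, as in the array above.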

This code shows an example of a classification metric. This custom metric function computes the macro-averaged error rate value given predictions and targets. This equation defines the macro error rate:

$$\mathrm{errorRate}^{(\mathrm{macro})} = \frac{1}{K}\sum_{i=1}^{K}\frac{FP_i + FN_i}{TP_i + TN_i + FP_i + FN_i},$$

where $TP_i$, $TN_i$, $FP_i$, and $FN_i$ represent the number of true positives, true negatives, false positives, and false negatives, respectively, in class $i$, and $K$ is the number of classes.

function val = errorRate(Y,T)
% Compute macro error rate value.

% Find the channel (class) and batch dimensions.
cDim = finddim(Y,"C");
bDim = finddim(Y,"B");

% Find the maximum score. This corresponds to the predicted
% class. Set the predicted class as 1 and all other classes as 0.
Y = Y == max(Y,[],cDim);

% Find the TP, FP, FN, and TN counts for each class in this batch.
TP = sum(Y & T, bDim);
FP = sum(Y & ~T, bDim);
FN = sum(~Y & T, bDim);
TN = sum(~Y & ~T, bDim);

% Compute the error rate value and average across each class.
val = mean((FP + FN)./(TP + TN + FP + FN));
end
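The computation can be traced on a small batch. This sketch repeats the steps of the errorRate function inline. The scores, the targets, and the "CB" format are made up for illustration, and the variable Ypred is a name introduced here for the thresholded predictions.

```matlab
% Hypothetical scores: 3 classes (rows) by 4 observations (columns).
scores = [0.7 0.1 0.2 0.6
          0.2 0.8 0.3 0.3
          0.1 0.1 0.5 0.1];
% Hypothetical one-hot targets: the true classes are 1, 2, 1, 1.
targets = [1 0 1 1
           0 1 0 0
           0 0 0 0];

Y = dlarray(scores,"CB");
T = dlarray(targets,"CB");

% Find the channel (class) and batch dimensions.
cDim = finddim(Y,"C");
bDim = finddim(Y,"B");

% Predicted classes are 1, 2, 3, 1, so only observation 3 is wrong.
Ypred = Y == max(Y,[],cDim);

% Per-class counts, summed over the batch.
TP = sum(Ypred & T, bDim);
FP = sum(Ypred & ~T, bDim);
FN = sum(~Ypred & T, bDim);
TN = sum(~Ypred & ~T, bDim);

% Per-class error rates are 1/4, 0, and 1/4, so the macro average is 1/6.
val = mean((FP + FN)./(TP + TN + FP + FN));
disp(val)
```

The single misclassified observation counts as a false negative for class 1 and as a false positive for class 3, which is why two of the three per-class error rates are nonzero.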

Tip

If your metric contains a fraction whose denominator can be zero, you can add eps to the denominator to prevent the metric from returning a NaN value.
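As a sketch of this tip, the snippet below compares a SMAPE-style calculation with and without eps in the denominator. The two anonymous functions and the input values are hypothetical and written in plain MATLAB. The first observation has both prediction and target equal to zero, which makes the unguarded ratio 0/0.

```matlab
% Hypothetical inputs. The first observation is zero for both the
% prediction and the target, so its unguarded ratio is 0/0.
Y = [0 2];
T = [0 6];

% SMAPE-style metric without and with eps in the denominator.
smapeUnguarded = @(Y,T) 100*mean(abs(Y-T)./((abs(Y) + abs(T))./2),"all");
smapeGuarded = @(Y,T) 100*mean(abs(Y-T)./((abs(Y) + abs(T))./2 + eps),"all");

valUnguarded = smapeUnguarded(Y,T);   % NaN, because 0/0 is NaN
valGuarded = smapeGuarded(Y,T);       % 50, because 0/eps is 0

fprintf("unguarded: %g, guarded: %g\n",valUnguarded,valGuarded)
```

Because eps is tiny relative to typical denominators, the guarded version leaves the other terms effectively unchanged.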
