Define Custom Deep Learning Metric Object

Note

This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.

In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.
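As a minimal illustration of comparing predictions to ground truth, this code computes two of these metrics, accuracy and root mean squared error, using base MATLAB functions. The variable names and values are made up for illustration.

```matlab
% Made-up class labels for a classification task.
predictedLabels = [1 2 2 3 1];
trueLabels      = [1 2 3 3 1];

% Accuracy: fraction of predictions that match the ground truth.
accuracy = mean(predictedLabels == trueLabels);

% Made-up responses for a regression task.
predictedValues = [2.0 3.5 1.0];
targetValues    = [2.0 3.0 2.0];

% Root mean squared error: square root of the mean squared difference.
rmse = sqrt(mean((predictedValues - targetValues).^2));
```

Here, four of the five labels match, so accuracy is 0.8, and rmse is sqrt(1.25/3), about 0.6455.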

How to Decide Which Metric Type to Use

If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.

Metric Template

To define a custom metric, use this class definition template as a starting point. For an example that shows how to use this template to create a custom metric, see Define Custom F-Beta Score Metric Object.

The template outlines how to specify these aspects of the class definition:

  • The properties block for public metric properties. This block must contain the Name property.

  • The properties block for private metric properties. This block is optional.

  • The metric constructor function.

  • The optional initialize function.

  • The required reset, update, aggregate, and evaluate functions.

For information about when the software calls each function, see Function Call Order.

classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties
            %
            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %           evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with val1,...
            % valN.

            % Define metric evaluation function here.
        end
    end
end

Metric Properties

Declare the metric properties in the properties blocks. You can specify attributes in the class definition to customize the behavior of properties. This template defines two types of properties by setting the Access attribute, which controls which code can access each property.

  • properties — Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.

  • properties (Access = private) — Only members of the defining class can access the property.

Public Properties

Declare public metric properties in the properties section of the class definition. These properties have public access, which means any code can access the values. By default, custom metrics have the NetworkOutput public property with the default value []. The NetworkOutput property defines which network output to apply the metric to.

You must define the Name property in this block. The Name property controls the name of the metric in any plots or command line output.

Private Properties

Declare private metric properties in the properties (Access = private) section of the class definition. These properties have private access, which means only members of the defining class can access these properties. For example, the class functions can access private properties. If the metric has no private properties, then you can omit this properties section.

Constructor Function

The constructor function creates the metric and initializes the metric properties. The constructor function must take as input any variables that you need to compute the metric. This function must have the same name as the class.

To use any properties as name-value arguments, you must set them in the constructor function. All metrics must support the optional Name argument.

Tip

To use the NetworkOutput property as a name-value argument, you must set the property in the constructor function.

Initialization Function

The initialize function is an optional function that the software calls after reading the first batch of data. You can use this function to initialize variables and run validation checks.

The initialize function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.

metric = initialize(metric,batchY,batchT)

Example initialize Function

This code shows an example of an initialize function that checks that you are using the metric for a network with a single output and therefore only one set of batch predictions and targets.

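        function metric = initialize(metric,varargin)
            % varargin holds the mini-batch predictions followed by the
            % targets. A network with a single output passes exactly two.
            if numel(varargin) ~= 2
                error("Metric not supported for networks with multiple outputs.")
            end
        end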

Reset Function

The reset function resets the metric properties. The software calls this function before each iteration. For more information, see Function Call Order.

The reset function must have this syntax.

metric = reset(metric)
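As a sketch, suppose a hypothetical binary F1-score metric stores running counts in private properties named TruePositives, FalsePositives, and FalseNegatives (these names are assumptions for illustration). This code models the metric as a struct so that it runs without defining a class. The assignments show what the body of a reset method might contain.

```matlab
% Hypothetical metric state left over from a previous iteration, modeled
% as a struct that stands in for a custom metric object.
metric = struct('Name',"F1Score",'TruePositives',7, ...
    'FalsePositives',2,'FalseNegatives',3);

% Possible body of reset: zero the accumulators so that the next iteration
% starts from a clean state. Public properties such as Name keep their values.
metric.TruePositives  = 0;
metric.FalsePositives = 0;
metric.FalseNegatives = 0;
```

In a class definition, the same assignments would appear inside the reset method and act on the object properties.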

Update Function

The update function updates the metric properties that you use to compute the metric value. The software calls update for each training and validation mini-batch. For more information, see Function Call Order.

The update function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.

metric = update(metric,batchY,batchT)
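This code sketches what the body of an update method might do for a hypothetical binary F1-score metric: threshold the predictions and add the counts for the current mini-batch to the running totals. The property names and the 0.5 threshold are assumptions. Plain numeric row vectors stand in for the mini-batch; in practice, the inputs are typically formatted dlarray objects, so an actual update method may need to locate the relevant dimensions first.

```matlab
% Hypothetical metric state after reset, modeled as a struct.
metric = struct('TruePositives',0,'FalsePositives',0,'FalseNegatives',0);

% Made-up mini-batch: predicted probabilities of the positive class and
% targets, where 1 marks a positive observation.
batchY = [0.9 0.2 0.7 0.4 0.8];
batchT = [1 0 0 1 1];

% Possible body of update: classify with an assumed 0.5 threshold and
% add the counts for this mini-batch to the running totals.
predicted = batchY >= 0.5;
actual    = batchT == 1;
metric.TruePositives  = metric.TruePositives  + nnz(predicted & actual);
metric.FalsePositives = metric.FalsePositives + nnz(predicted & ~actual);
metric.FalseNegatives = metric.FalseNegatives + nnz(~predicted & actual);
```

Because this update adds to the existing totals instead of overwriting them, calling it for several validation mini-batches accumulates counts over the whole validation set.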

Aggregation Function

The aggregate function specifies how to combine properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. For more information, see Function Call Order.

The aggregate function must have this syntax, where the metric2 input is another instance of the metric. To ensure that your function always produces the same results regardless of how the partial results are grouped, make sure that aggregate is an associative function.

metric = aggregate(metric,metric2)
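For count-based state, such as the counts of the hypothetical F1-score metric, aggregation can be a field-by-field sum. This code shows a possible body of an aggregate method, with structs standing in for the two metric instances. The property names are assumptions.

```matlab
% Partial results from two hypothetical metric instances, each of which
% processed a different subset of the same mini-batch.
metric  = struct('TruePositives',2,'FalsePositives',1,'FalseNegatives',1);
metric2 = struct('TruePositives',3,'FalsePositives',0,'FalseNegatives',2);

% Possible body of aggregate: add the counts. Addition is associative, so
% the combined counts do not depend on how the partial results are grouped.
metric.TruePositives  = metric.TruePositives  + metric2.TruePositives;
metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
metric.FalseNegatives = metric.FalseNegatives + metric2.FalseNegatives;
```

By contrast, averaging the metric values of two instances is not associative: the mean of (the mean of a and b) and c generally differs from the mean of a and (the mean of b and c), so an aggregate function built that way can give results that depend on the grouping.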

Evaluation Function

The evaluate function specifies how to compute the metric value. In most cases, the final metric value is a function of the metric properties.

For the training data, the software calls evaluate at the end of each mini-batch. For the validation data, the software calls evaluate after all of the data passes through the network. Therefore, the software computes a metric value for each mini-batch of training data, but a single metric value for the whole validation set. For more information, see Function Call Order.

The evaluate function must have this syntax, where M is the number of metric values to return.

[val1,...,valM] = evaluate(metric)
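To complete the hypothetical F1-score example, this code computes the metric value from the accumulated counts using F1 = 2TP/(2TP + FP + FN), which is algebraically equal to the harmonic mean of precision and recall. As before, a struct with assumed property names stands in for the metric object.

```matlab
% Hypothetical accumulated counts after update (and aggregate, if used).
metric = struct('TruePositives',5,'FalsePositives',1,'FalseNegatives',3);

tp = metric.TruePositives;
fp = metric.FalsePositives;
fn = metric.FalseNegatives;

% Possible body of evaluate: compute the final value from the properties.
% If the data contains no positive predictions and no positive targets,
% the denominator is zero and this expression returns NaN.
val = 2*tp / (2*tp + fp + fn);
```

With these counts, val is 10/14, about 0.714.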

Function Call Order

The order in which the software calls the initialize, reset, update, aggregate, and evaluate functions depends on where in the training loop the software is. The software calls initialize first, after it reads the first batch of data.

The order in which the software calls the remaining functions depends on whether the data is training or validation data.

  • Training data — For each mini-batch, the software calls reset, then update, and then evaluate. Therefore, the software returns the metric value for each training mini-batch, where each batch is equivalent to a single training iteration.

  • Validation data — The software calls reset before the first validation mini-batch, and then calls only update for each mini-batch. After all of the validation data passes through the network, the software calls evaluate. Therefore, the software returns the metric value for the whole validation set (full-batch). This behavior is equivalent to a validation iteration.
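This code mimics the call order for a hypothetical mean absolute error metric. The metric state is modeled as a struct, and the reset, update, and evaluate functions are modeled as anonymous functions; all names are assumptions. The code does not model initialize or parallel aggregation.

```matlab
% Hypothetical metric functions that operate on a struct of running totals.
resetFcn    = @(m) struct('SumAbsError',0,'NumObservations',0);
updateFcn   = @(m,Y,T) struct('SumAbsError',m.SumAbsError + sum(abs(Y - T)), ...
    'NumObservations',m.NumObservations + numel(T));
evaluateFcn = @(m) m.SumAbsError / m.NumObservations;

% Made-up predictions and targets, one cell per mini-batch.
trainY = {[1 2], [3 3]};    trainT = {[1 4], [2 3]};
valY   = {[0 1], [2 2 2]};  valT   = {[1 1], [2 5 2]};

% Training data: reset, update, and then evaluate for every mini-batch,
% giving one metric value per training iteration.
metric = resetFcn([]);
trainValues = zeros(1,numel(trainY));
for i = 1:numel(trainY)
    metric = resetFcn(metric);
    metric = updateFcn(metric,trainY{i},trainT{i});
    trainValues(i) = evaluateFcn(metric);
end

% Validation data: reset once, update for every mini-batch, and then
% evaluate once over the whole validation set.
metric = resetFcn(metric);
for i = 1:numel(valY)
    metric = updateFcn(metric,valY{i},valT{i});
end
valValue = evaluateFcn(metric);
```

With these values, trainValues is [1 0.5] and valValue is 0.8. Averaging the two per-batch validation values instead would give 0.75, because the validation mini-batches have different sizes.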

This diagram illustrates the difference between how the software computes the metric for the training and validation data.

Note

When you train a network using the L-BFGS solver, the software processes all of the data in a single batch. This behavior is equivalent to a single mini-batch with all of the observations.

Aggregate Data

The aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. Finally, the software calls evaluate to obtain the metric value for the whole training mini-batch.
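This code sketches that flow for a hypothetical accuracy metric whose state is a count of correct predictions and a count of observations. Structs and anonymous functions stand in for the metric object and its functions, and a loop over index subsets stands in for the parallel workers; none of these names come from Deep Learning Toolbox. Because the counts are summed, the aggregated value matches the value from processing the whole mini-batch at once.

```matlab
% Hypothetical metric functions that operate on a struct of counts.
resetFcn     = @() struct('NumCorrect',0,'NumObservations',0);
updateFcn    = @(m,Y,T) struct('NumCorrect',m.NumCorrect + nnz(Y == T), ...
    'NumObservations',m.NumObservations + numel(T));
aggregateFcn = @(m,m2) struct('NumCorrect',m.NumCorrect + m2.NumCorrect, ...
    'NumObservations',m.NumObservations + m2.NumObservations);
evaluateFcn  = @(m) m.NumCorrect / m.NumObservations;

% One made-up training mini-batch of predicted and true class labels.
Y = [1 2 2 3 1 1 2];
T = [1 2 3 3 2 1 2];

% Split the mini-batch into unequal subsets, one per simulated worker,
% and update a separate metric instance for each subset.
subsetIdx = {1:3, 4:5, 6:7};
partial = cell(1,numel(subsetIdx));
for k = 1:numel(subsetIdx)
    partial{k} = updateFcn(resetFcn(),Y(subsetIdx{k}),T(subsetIdx{k}));
end

% Consolidate the partial results, then evaluate once for the mini-batch.
metric = partial{1};
for k = 2:numel(partial)
    metric = aggregateFcn(metric,partial{k});
end
parallelValue = evaluateFcn(metric);

% For comparison: the same mini-batch processed in one piece, and a naive
% average of the per-subset accuracies.
serialValue = evaluateFcn(updateFcn(resetFcn(),Y,T));
naiveValue  = mean(cellfun(evaluateFcn,partial));
```

Here, parallelValue and serialValue are both 5/7 (about 0.714), whereas naiveValue is 13/18 (about 0.722), because the subsets have different sizes.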
