Customize Output During Deep Learning Network Training

This example shows how to define an output function that runs at each iteration during training of deep learning neural networks. If you specify output functions by using the 'OutputFcn' name-value pair argument of trainingOptions, then trainNetwork calls these functions once before the start of training, after each training iteration, and once after training has finished. Each time the output functions are called, trainNetwork passes a structure containing information such as the current iteration number, loss, and accuracy. You can use output functions to display or plot progress information, or to stop training. To stop training early, make your output function return true. If any output function returns true, then training finishes and trainNetwork returns the latest network.
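As a minimal sketch of the interface (myOutputFcn is a hypothetical name; the field names shown are those trainNetwork supplies in the info structure), an output function that only logs progress and never stops training looks like this:

```matlab
function stop = myOutputFcn(info)
% Minimal output function sketch: print progress, never stop training.
% trainNetwork passes info with fields including State ("start",
% "iteration", or "done"), Epoch, Iteration, and TrainingLoss.
stop = false;
if info.State == "iteration"
    fprintf("Epoch %d, iteration %d: loss %.4f\n", ...
        info.Epoch,info.Iteration,info.TrainingLoss);
end
end
```

You attach it with 'OutputFcn',@myOutputFcn in the trainingOptions call.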

To stop training when the loss on the validation set stops decreasing, specify validation data and a validation patience using the 'ValidationData' and 'ValidationPatience' name-value pair arguments of trainingOptions, respectively. The validation patience is the number of times that the loss on the validation set is allowed to be larger than or equal to the previously smallest loss before network training stops. You can add further stopping criteria using output functions. This example shows how to create an output function that stops training when the classification accuracy on the validation data stops improving. The output function is defined at the end of the script.
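For comparison, a sketch of the built-in loss-based criterion (assuming validation data in XValidation and YValidation, as prepared below) needs only the two name-value arguments:

```matlab
% Built-in early stopping: stop when the validation loss fails to
% improve for 5 consecutive validations.
options = trainingOptions('sgdm', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationPatience',5);
```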

Load the training data, which contains 5000 images of digits. Set aside 1000 of the images for network validation.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Construct a network to classify the digit image data.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Specify options for network training. To validate the network at regular intervals during training, specify validation data. Choose the 'ValidationFrequency' value so that the network is validated once per epoch.

To stop training when the classification accuracy on the validation set stops improving, specify stopIfAccuracyNotImproving as an output function. The second input argument of stopIfAccuracyNotImproving is the number of times that the accuracy on the validation set can be smaller than or equal to the previously highest accuracy before network training stops. Choose any large value for the maximum number of epochs to train. Training should not reach the final epoch because training stops automatically.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));
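The 'OutputFcn' option also accepts a cell array of function handles, so you can combine several criteria; training stops as soon as any function returns true. A sketch (progressLogger is a hypothetical second output function):

```matlab
% Combine multiple output functions in a cell array. Training stops
% when any one of them returns true.
options = trainingOptions('sgdm', ...
    'OutputFcn',{@(info)stopIfAccuracyNotImproving(info,3), ...
                 @progressLogger});
```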

Train the network. Training stops when the validation accuracy stops increasing.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:05 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:13 |       71.88% |       74.90% |       0.8807 |       0.8130 |          0.0100 |
|       2 |          62 |       00:00:22 |       86.72% |       88.00% |       0.3899 |       0.4436 |          0.0100 |
|       3 |          93 |       00:00:33 |       94.53% |       94.00% |       0.2224 |       0.2553 |          0.0100 |
|       4 |         124 |       00:00:43 |       95.31% |       96.80% |       0.1482 |       0.1762 |          0.0100 |
|       5 |         155 |       00:00:52 |       98.44% |       97.60% |       0.1007 |       0.1314 |          0.0100 |
|       6 |         186 |       00:01:00 |       99.22% |       97.80% |       0.0784 |       0.1136 |          0.0100 |
|       7 |         217 |       00:01:08 |      100.00% |       98.10% |       0.0559 |       0.0945 |          0.0100 |
|       8 |         248 |       00:01:15 |      100.00% |       98.00% |       0.0441 |       0.0859 |          0.0100 |
|       9 |         279 |       00:01:23 |      100.00% |       98.00% |       0.0344 |       0.0786 |          0.0100 |
|      10 |         310 |       00:01:30 |      100.00% |       98.50% |       0.0274 |       0.0678 |          0.0100 |
|      11 |         341 |       00:01:37 |      100.00% |       98.50% |       0.0240 |       0.0621 |          0.0100 |
|      12 |         372 |       00:01:42 |      100.00% |       98.70% |       0.0213 |       0.0569 |          0.0100 |
|      13 |         403 |       00:01:49 |      100.00% |       98.80% |       0.0187 |       0.0534 |          0.0100 |
|      14 |         434 |       00:01:55 |      100.00% |       98.80% |       0.0164 |       0.0508 |          0.0100 |
|      15 |         465 |       00:02:03 |      100.00% |       98.90% |       0.0144 |       0.0487 |          0.0100 |
|      16 |         496 |       00:02:10 |      100.00% |       99.00% |       0.0126 |       0.0462 |          0.0100 |
|      17 |         527 |       00:02:17 |      100.00% |       98.90% |       0.0112 |       0.0440 |          0.0100 |
|      18 |         558 |       00:02:23 |      100.00% |       98.90% |       0.0101 |       0.0420 |          0.0100 |
|      19 |         589 |       00:02:29 |      100.00% |       99.10% |       0.0092 |       0.0405 |          0.0100 |
|      20 |         620 |       00:02:36 |      100.00% |       99.00% |       0.0086 |       0.0391 |          0.0100 |
|      21 |         651 |       00:02:42 |      100.00% |       99.00% |       0.0080 |       0.0380 |          0.0100 |
|      22 |         682 |       00:02:49 |      100.00% |       99.00% |       0.0076 |       0.0369 |          0.0100 |
|======================================================================================================================|
Training finished: Stopped by OutputFcn.

Figure: Training Progress plot with two axes, Loss versus Iteration and Accuracy (%) versus Iteration.

Define Output Function

Define the output function stopIfAccuracyNotImproving(info,N), which stops network training if the best classification accuracy on the validation data does not improve for N network validations in a row. This criterion is similar to the built-in stopping criterion using the validation loss, except that it applies to the classification accuracy instead of the loss.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;

elseif ~isempty(info.ValidationAccuracy)

    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end

    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end

end

end
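You can exercise the stopping logic outside of training by passing hand-built info structures, for example at the command line with stopIfAccuracyNotImproving saved in its own file (a sketch; the field names match those trainNetwork supplies):

```matlab
% Reset the persistent state, then simulate three non-improving validations.
stopIfAccuracyNotImproving(struct("State","start","ValidationAccuracy",[]),3);
info = struct("State","iteration","ValidationAccuracy",90);
stopIfAccuracyNotImproving(info,3);  % best accuracy becomes 90, lag 0
info.ValidationAccuracy = 89;
stopIfAccuracyNotImproving(info,3);  % lag 1
stopIfAccuracyNotImproving(info,3);  % lag 2
stop = stopIfAccuracyNotImproving(info,3)  % lag 3 >= N, returns true
```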
