How to plot the loss function on the overall dataset in Training Progress

I am writing a Convolutional Neural Network for regression in MATLAB R2021b. I'm using the trainNetwork function in Deep Learning Toolbox and in options I have 'Plots','training-progress'.
I understand that the Training Progress plot shows, for each iteration, the MSE value computed on the mini-batch.
I would like to know whether the Training Progress plot can instead show the MSE computed on the entire Training Set and the entire Validation Set, rather than on the mini-batch.
Thanks in advance.

Answers (1)

Aneela on 13 Sep 2024
Edited: Aneela on 22 Sep 2024
Hi Maria,
The trainNetwork function's default training progress plot displays, for each iteration, the loss and the metric computed on the current mini-batch (for a regression network, the loss and the RMSE).
  • It has no option to display the loss computed over the entire training set during training. If you specify 'ValidationData' in trainingOptions, the validation curve is evaluated on the validation set every 'ValidationFrequency' iterations.
  • To plot the MSE over the full training set, replace trainNetwork with a custom training loop.
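For the validation set alone, a custom loop may not be necessary. The sketch below shows the trainingOptions settings that, as far as I know, make trainNetwork evaluate the validation data periodically and add those points to the training progress plot. The array sizes and the validation frequency are placeholder assumptions, not values from the question.

```matlab
% Placeholder validation data (hypothetical shapes):
% 16 grayscale 8x8 images with one response each.
XValidation = rand(8,8,1,16,'single');
YValidation = rand(16,1,'single');

% Ask trainNetwork to evaluate the validation data every 10 iterations
% and show the result in the training progress plot.
options = trainingOptions('adam', ...
    'MiniBatchSize',16, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',10, ...
    'Plots','training-progress');
```

Pass options to trainNetwork as usual. The training curve in the plot is still computed per mini-batch.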
Here’s a possible workaround:
  • Set hyperparameters such as the learning rate, number of epochs, and mini-batch size.
  • Iterate over the epochs.
  • Within each epoch, iterate over the mini-batches and compute the predictions for each mini-batch.
  • Compute the gradients of the loss with respect to the learnable parameters.
  • Update the learnable parameters with an optimizer (for example, adamupdate or sgdmupdate) to minimize the loss.
  • Compute the MSE over the training and validation datasets after each mini-batch update. Here’s a sample code snippet:
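% net: dlnetwork being trained; (XTrain, YTrain): training data;
% (XValidation, YValidation): validation data.
% Assumes X* are H-by-W-by-C-by-N arrays and Y* are R-by-N response arrays.
YPredTrain = extractdata(predict(net, dlarray(single(XTrain), 'SSCB')));
trainMSE = double(mean((YPredTrain - YTrain).^2, 'all'));  % average over all elements
YPredValidation = extractdata(predict(net, dlarray(single(XValidation), 'SSCB')));
validationMSE = double(mean((YPredValidation - YValidation).^2, 'all'));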
  • Plot the MSE for both training and validation datasets throughout the training process using “addpoints” and “drawnow”.
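% Create the lines once, before the training loop
trainingPlot = animatedline('Color','r');
validationPlot = animatedline('Color','b');
% Inside the loop, after computing trainMSE and validationMSE for this iteration
addpoints(trainingPlot, iteration, trainMSE);
addpoints(validationPlot, iteration, validationMSE);
drawnow;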
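Putting the steps together, a minimal sketch of such a loop could look like the following. Everything specific in it is an assumption made so that the script runs on its own: the synthetic data, the tiny network, the hyperparameter values, and the helper names modelGradients and fullSetMSE are all hypothetical. Save it as a script file, since the local functions must sit at the end of the file.

```matlab
% Synthetic regression data (assumption): the target is the mean pixel value.
rng(0);
XTrain = rand(8,8,1,64,'single');
YTrain = reshape(mean(XTrain,[1 2 3]),1,[]);   % 1-by-N responses
XVal   = rand(8,8,1,16,'single');
YVal   = reshape(mean(XVal,[1 2 3]),1,[]);

% Tiny placeholder network; no output layer is needed for a dlnetwork.
layers = [
    imageInputLayer([8 8 1],'Normalization','none')
    convolution2dLayer(3,4,'Padding','same')
    reluLayer
    fullyConnectedLayer(1)];
net = dlnetwork(layerGraph(layers));

% Hyperparameters (arbitrary illustrative values).
numEpochs = 5; miniBatchSize = 16; learnRate = 1e-2;
avgGrad = []; avgSqGrad = []; iteration = 0;
numObs = size(XTrain,4);

% Create the plot lines once, before the loop.
figure
trainLine = animatedline('Color','r');
valLine   = animatedline('Color','b');
xlabel('Iteration'); ylabel('MSE');
legend([trainLine valLine],{'Training set','Validation set'});

for epoch = 1:numEpochs
    idx = randperm(numObs);                     % shuffle every epoch
    for i = 1:miniBatchSize:numObs
        iteration = iteration + 1;
        batch = idx(i:min(i+miniBatchSize-1,numObs));
        X = dlarray(XTrain(:,:,:,batch),'SSCB');
        T = dlarray(YTrain(:,batch),'CB');

        % Mini-batch loss and gradients, then an Adam update.
        [loss,gradients] = dlfeval(@modelGradients,net,X,T);
        [net,avgGrad,avgSqGrad] = adamupdate(net,gradients, ...
            avgGrad,avgSqGrad,iteration,learnRate);

        % MSE over the full training and validation sets.
        trainMSE = fullSetMSE(net,XTrain,YTrain);
        valMSE   = fullSetMSE(net,XVal,YVal);
        addpoints(trainLine,iteration,trainMSE);
        addpoints(valLine,iteration,valMSE);
        drawnow limitrate
    end
end

function [loss,gradients] = modelGradients(net,X,T)
    % Forward pass in training mode; mse returns half mean squared error.
    Y = forward(net,X);
    loss = mse(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end

function m = fullSetMSE(net,X,Y)
    % Inference over the whole set at once, then plain MSE over all elements.
    YPred = extractdata(predict(net,dlarray(X,'SSCB')));
    m = double(mean((YPred - Y).^2,'all'));
end
```

Evaluating both full datasets after every iteration is costly on anything larger than a toy problem. Moving the two fullSetMSE calls to the end of each epoch, or running them every few iterations, keeps the same plot at a fraction of the cost. For large datasets, the full-set prediction would also need to be done in chunks to fit in memory.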
Refer to the following MathWorks documentation for more information on:
