Can you plot the gradient for CNNs using trainNetwork?

I am using the trainNetwork command to train my network, but noticed that there is no way to plot the gradients over iterations. The trainInfo output does contain some information, but does not seem to contain any information about the gradient.

Answers (1)

Snehal on 27 Mar 2025
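I understand that you want to extract gradient information while training a CNN and plot it over iterations. The ‘trainNetwork’ function in MATLAB does not expose gradients during training, but there are two possible workarounds:
  • Train with a custom training loop instead: convert the layers to a ‘dlnetwork’ and compute the gradients yourself with ‘dlgradient’, which must be called inside a function that is evaluated with ‘dlfeval’. Note that ‘dlnetwork’, ‘dlfeval’ and ‘dlgradient’ require R2019b or later. Below is a sample code snippet: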
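net = dlnetwork(layers); % 'layers' is a layer array defined earlier, ending in softmaxLayer (no output layer)
% Assume 'XBatch' is a formatted dlarray (e.g. 'SSCB') and 'YBatch' is a dlarray of one-hot targets ('CB')
% dlgradient only works inside a function that is evaluated with dlfeval
[loss, gradients] = dlfeval(@modelLoss, net, XBatch, YBatch);
% 'gradients' now contains the gradients of the loss with respect to the learnable parameters,
% as a table with the same Layer/Parameter/Value layout as net.Learnables

% Local functions must appear at the end of a script file
function [loss, gradients] = modelLoss(net, X, T)
    YPred = forward(net, X);                        % forward pass
    loss = crossentropy(YPred, T);                  % compute loss
    gradients = dlgradient(loss, net.Learnables);   % compute gradients
end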
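  • If you prefer to keep using ‘trainNetwork’, you can plot from a custom output function, passed through the ‘OutputFcn’ option of ‘trainingOptions’. The information structure passed to that function does not include gradients either, but fields such as ‘TrainingLoss’ and ‘ValidationLoss’ are available over iterations, and how quickly they change can serve as an indirect indicator of gradient-related behaviour during training.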
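The snippet above computes gradients for a single batch. To actually plot them over iterations, one option is to record a summary statistic, such as the global L2 norm of all gradients, at every step of a custom training loop. The following is a minimal sketch under stated assumptions: the small architecture, the synthetic random data, the learning rate, the use of the whole synthetic set as one batch per iteration, and the helper name `modelLoss` are all illustrative, and it needs R2019b or later (for `dlnetwork`, `dlfeval`, `dlgradient` and `sgdmupdate`).

```matlab
% Minimal sketch: log the global L2 norm of the gradients at every iteration
% of a custom training loop, then plot it. All data here is synthetic and
% purely illustrative.
rng(0);
numClasses = 3;
numObs = 64;
numIterations = 50;
learnRate = 0.01;

% Hypothetical training data: random images and random integer labels
XData = rand(28, 28, 1, numObs, 'single');
labels = randi(numClasses, 1, numObs);
TData = zeros(numClasses, numObs, 'single');             % one-hot targets
TData(sub2ind(size(TData), labels, 1:numObs)) = 1;

X = dlarray(XData, 'SSCB');   % spatial, spatial, channel, batch
T = dlarray(TData, 'CB');     % channel, batch

% Illustrative network; dlnetwork needs no output (classification) layer
layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers);

gradNorm = zeros(numIterations, 1);
velocity = [];
for iteration = 1:numIterations
    % Gradients of the loss w.r.t. every learnable parameter
    [~, gradients] = dlfeval(@modelLoss, net, X, T);

    % Sum of squared gradient entries for each parameter, then global norm
    sq = cellfun(@(g) sum(extractdata(g).^2, 'all'), gradients.Value);
    gradNorm(iteration) = sqrt(sum(sq));

    % SGD with momentum update of the network parameters
    [net, velocity] = sgdmupdate(net, gradients, velocity, learnRate);
end

figure;
plot(1:numIterations, gradNorm);
xlabel('Iteration');
ylabel('Global L2 norm of gradients');
grid on;

% Local functions must appear at the end of a script file
function [loss, gradients] = modelLoss(net, X, T)
    Y = forward(net, X);                            % forward pass
    loss = crossentropy(Y, T);                      % cross-entropy loss
    gradients = dlgradient(loss, net.Learnables);   % backpropagation
end
```

The global norm is just one choice of summary; to look at individual layers, you could instead plot the per-row values in `sq` (one per entry of `net.Learnables`) over iterations.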
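If you stay with ‘trainNetwork’, a custom output function can at least plot, live, how the training loss changes from one iteration to the next. This is only an indirect proxy, since the structure that ‘trainNetwork’ passes to the output function does not carry gradients. The sketch below assumes synthetic random data, an illustrative architecture and training settings, and a hypothetical helper named `plotLossChange`; it uses the documented ‘OutputFcn’ option of ‘trainingOptions’ and the `State`, `Iteration` and `TrainingLoss` fields of the info structure.

```matlab
% Minimal sketch: plot the iteration-to-iteration change in TrainingLoss
% during trainNetwork via a custom 'OutputFcn'. Data and architecture are
% illustrative assumptions only.
rng(0);
XTrain = rand(28, 28, 1, 64, 'single');          % hypothetical images
YTrain = categorical(randi(3, 64, 1), 1:3);      % hypothetical labels, 3 classes

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(3)
    softmaxLayer
    classificationLayer];

options = trainingOptions('sgdm', ...
    'MaxEpochs', 5, ...
    'MiniBatchSize', 16, ...
    'Verbose', false, ...
    'OutputFcn', @plotLossChange);

% 'info' also records TrainingLoss per iteration for plotting afterwards
[~, info] = trainNetwork(XTrain, YTrain, layers, options);

% Local functions must appear at the end of a script file
function stop = plotLossChange(info)
    % trainNetwork calls this with info.State 'start', 'iteration' or 'done'
    persistent prevLoss hLine
    stop = false;   % returning true would stop training early
    if strcmp(info.State, 'start')
        prevLoss = [];
        figure;
        hLine = animatedline;
        xlabel('Iteration');
        ylabel('Change in TrainingLoss');
    elseif strcmp(info.State, 'iteration') && ~isempty(info.TrainingLoss)
        if ~isempty(prevLoss)
            addpoints(hLine, info.Iteration, info.TrainingLoss - prevLoss);
            drawnow limitrate
        end
        prevLoss = info.TrainingLoss;
    end
end
```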
You can refer to the following documentation links for more information:
Hope this helps.

Categories

Find more on Image Data Workflows in Help Center and File Exchange

Release

R2018a
