How does dlgradient handle two output layers?

8 views (last 30 days)
I have a deep learning image analysis network with a custom training loop and loss function. In this loss function I take the output layer's output and the second-to-last layer's output from the net, compute mse (using the built-in function) on both, and sum the results. My output layer is a custom layer, so I have control over its backward function, but I cannot see the automatic backward functions of the other layers.
When I use dlgradient on this summed loss, how does MATLAB handle two simultaneous rounds of differentiation? Does it sum the gradients for each learnable, or does one gradient calculation override the other?
I have tried both the Adam and SGDM update methods, but since my network is not learning as expected, I suspect I need to adjust the backward function in my custom layer to ensure proper gradient backpropagation.
EDIT: I added some abstract code to demonstrate the context of my question. Specifically, I'm asking whether anyone knows what goes on under the hood when backpropagation starts from two places: how dlgradient combines the backward-propagated gradients as it computes the derivatives, so I can better understand how to craft my backward function.
function dlgradientForumQuestion()
    % placeholder sizes (abstract example -- real values come from my data)
    imageSize = 64;
    numChannels = 1;
    numFilters = 8;
    numTrainingIterations = 100;
    % create a network (customLayer is my own layer with a custom backward function)
    modelLayers = [
        imageInputLayer([imageSize,imageSize,numChannels],"Name","image_input")
        convolution2dLayer([3,3],numFilters,"Name","processing")
        customLayer("Name","output_layer")
        ];
    % initialize network
    myNetwork = dlnetwork(modelLayers);
    myNetwork = initialize(myNetwork);
    % Adam state must be carried across iterations
    averageGrad = [];
    averageSqGrad = [];
    % do training (inputData, targetData1, targetData2 come from my dataset)
    for iteration = 1:numTrainingIterations
        % use custom loss function; dlfeval traces the dlarray input
        [gradients,newState] = dlfeval(@customLossFunction,myNetwork, ...
            dlarray(inputData,"SSCB"),targetData1,targetData2);
        myNetwork.State = newState;
        % do the adamupdate here, keeping the moving averages for the next iteration
        [myNetwork,averageGrad,averageSqGrad] = adamupdate(myNetwork,gradients, ...
            averageGrad,averageSqGrad,iteration);
    end
end

function [gradients,newState] = customLossFunction(myNetwork,inputData,targetData1,targetData2)
    % forward through network, requesting the last and second-to-last outputs
    [customOutput,secondToLastOutput,newState] = forward(myNetwork,inputData, ...
        "Outputs",["output_layer","processing"]);
    % do loss mse on each output against its own target
    loss1 = mse(customOutput,targetData1);
    loss2 = mse(secondToLastOutput,targetData2);
    % sum the losses
    loss = loss1 + loss2;
    % get gradients of the summed loss with respect to all learnables
    gradients = dlgradient(loss,myNetwork.Learnables);
end
  2 Comments
Catalytic on 23 Aug 2024
So many words. So little code.
Violet on 23 Aug 2024
I have added some abstract code to demonstrate the context of the question.


Accepted Answer

Matt J on 23 Aug 2024
Edited: Matt J on 23 Aug 2024
Along the lines of what @Catalytic said, it is hard to envision how you are trying to use dlgradient without seeing that code. As a general point, however, there is no substantive difference in how backpropagation works when your network has forks in it. The fact that you have two outputs introduces a fork in the network graph just the same as if you had a network with skip connections. Backpropagation handles both the same way: wherever branches meet during the backward pass, the incoming gradients are summed (the multivariable chain rule), not overwritten. Hence dlgradient, which merely invokes backpropagation, works the same way as well.
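As a toy illustration of that additivity (a single scalar parameter, nothing specific to your network), you can check that the gradient dlgradient returns for a summed loss is the sum of the gradients of the individual losses:
% Toy check: d(loss1+loss2)/dw equals d(loss1)/dw + d(loss2)/dw
w    = dlarray(2);
g1   = dlfeval(@(w) dlgradient(w.^2, w), w);         % gradient of loss1 = w^2
g2   = dlfeval(@(w) dlgradient(3*w, w), w);          % gradient of loss2 = 3*w
gSum = dlfeval(@(w) dlgradient(w.^2 + 3*w, w), w);   % gradient of loss1 + loss2
fprintf("g1 + g2 = %g, gSum = %g\n", extractdata(g1 + g2), extractdata(gSum))
The same per-learnable additivity is what your two-output loss produces: every learnable that both losses depend on receives the sum of the two contributions.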
  4 Comments
Violet on 23 Aug 2024
That makes sense. As far as the propagated gradient goes, though, loss2 doesn't touch the output layer (from my understanding of writing a custom backward function). So I guess I don't understand how it does automatic differentiation on the learnables from two different error functions, or how it combines them.
That being said, perhaps there is no magic behind the scenes things going on. Thank you for taking the time to answer.
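For anyone who finds this later, here is the toy check I ran to convince myself (scalar stand-in parameters, nothing from my real network). A parameter that feeds both losses receives the summed contribution, while a parameter that only feeds the final output receives a contribution from loss1 alone:
% wShared stands in for a learnable in "processing" (feeds both losses);
% wOut stands in for a learnable in "output_layer" (feeds only loss1).
[gShared,gOut] = dlfeval(@toyGrads, dlarray(1.5), dlarray(0.5));
disp(extractdata(gShared))   % sum of contributions from loss1 and loss2
disp(extractdata(gOut))      % contribution from loss1 only

function [gShared,gOut] = toyGrads(wShared,wOut)
    hidden = 2*wShared;        % stand-in for the second-to-last output
    out    = hidden*wOut;      % stand-in for the final output
    loss1  = out.^2;           % loss on the final output
    loss2  = hidden.^2;        % loss on the second-to-last output
    [gShared,gOut] = dlgradient(loss1 + loss2, wShared, wOut);
end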
Matt J on 23 Aug 2024
You are quite welcome, but if/when you consider your question resolved, please Accept-click the answer.


More Answers (0)

Release: R2022a
