Customized loss function taking X as input in a CNN

Hello,
I am following this example ( https://www.mathworks.com/help/deeplearning/ug/train-network-using-custom-training-loop.html ) to define a customized training loop for a regression problem, because I want to pass the X data in each mini-batch to my customized loss function.
As in the linked example, I wrote modelGradients(), which calls my own loss function myLoss() instead of crossentropy(). However, I received the error:
Error using dlfeval (line 43)
Value to differentiate must be a traced dlarray scalar.
when trying to use the dlgradient() function.
Below is my code:
function [gradients,average_loss] = modelGradients(dlnet,dlX,Y)
    dlYPred = forward(dlnet,dlX);
    loss = myLoss(dlX,dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
end

function average_loss = myLoss(dlX,dlYPred,Y)
    Ypred = extractdata(dlYPred);
    X = extractdata(dlX);
    sum_loss = 0;
    % loop thru all the data in the mini batch and calculate the average
    % loss
    for i = 1:size(dlX,4)
        % f_calculate_loss() is a self-defined loss function takes X,
        % Ytarget and Ypred as input
        sum_loss = sum_loss + f_calculate_loss(X(:,:,:,i),Y(i,:),Ypred(:,i));
    end
    % calculate the average loss and need to convert the type into dlarray
    average_loss = sum_loss/size(dlX,4); % 1*1 dlarray
    average_loss = dlarray(average_loss);
end
In my code, dlX is a 4-D 3(S)*72(S)*1(C)*64(B) single dlarray, dlYPred is a 3*64 dlarray, and Y is a 64*3 double array, where 64 is the miniBatchSize. The calculated loss (nonzero) is a 1*1 dlarray.
I've been stuck on this issue for weeks and still can't understand what's going on. I would really appreciate it if anyone could explain what is happening here and how I should fix it. Thank you so much in advance!

Answers (1)

Hrishikesh Borate on 16 Jul 2021
Hi,
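The problem arises from calling extractdata before computing the gradient. extractdata returns plain numeric data, which breaks the derivative trace, so dlgradient no longer has a traced value to differentiate. Compute the loss using only functions that support dlarray.
For more information, refer to the Derivative Trace documentation.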
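As a rough illustration of the advice above, here is a minimal sketch of a loss that stays on the derivative trace. The original f_calculate_loss() is not shown in the question, so the loss below is a hypothetical stand-in: a squared error per observation, weighted by the mean absolute value of that observation's input. The weighting scheme and the variable names w, T, and sqErr are assumptions for illustration only. The point is that dlYPred is never passed through extractdata and the result is never re-wrapped with dlarray(), since either step would produce an untraced value.

```matlab
function [gradients,loss] = modelGradients(dlnet,dlX,Y)
    dlYPred = forward(dlnet,dlX);      % traced 3-by-B prediction
    loss = myLoss(dlX,dlYPred,Y);      % still traced, so dlgradient can use it
    gradients = dlgradient(loss,dlnet.Learnables);
end

function loss = myLoss(dlX,dlYPred,Y)
    % Hypothetical stand-in for f_calculate_loss (assumption, not the
    % asker's actual loss): input-weighted squared error.
    B = size(dlYPred,2);                       % mini-batch size
    T = single(Y');                            % targets as 3-by-B
    % Per-observation weight computed from X with dlarray-supported ops.
    Xflat = reshape(stripdims(dlX),[],B);      % (3*72*1)-by-B
    w = mean(abs(Xflat),1);                    % 1-by-B
    % Vectorized over the mini-batch instead of looping and extracting data.
    sqErr = sum((stripdims(dlYPred) - T).^2,1); % 1-by-B
    loss = sum(w.*sqErr,2)/B;                  % traced scalar dlarray
end
```

If the real loss cannot be vectorized, a for loop that indexes dlYPred directly (for example dlYPred(:,i)) and accumulates into the sum should also keep the trace, as long as every operation applied to the prediction supports dlarray.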
