# Customized loss function taking X as input in a CNN

Yichen Wu on 13 Jul 2021
Answered: Hrishikesh Borate on 16 Jul 2021
Hello,
I am following this example ( https://www.mathworks.com/help/deeplearning/ug/train-network-using-custom-training-loop.html ) to define a customized training loop for a regression problem, because I want to pass the X data of each mini-batch to my customized loss function.
As in the linked example, I wrote a `modelGradients()` function, but inside it I call my own loss function `myLoss()` instead of `crossentropy()`. However, when trying to use the `dlgradient()` function, I received this error:

```
Error using dlfeval (line 43)
Value to differentiate must be a traced dlarray scalar.
```
Below is my code:
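```matlab
function [gradients, loss] = modelGradients(dlnet, dlX, Y)
% Function header and dlgradient call added for completeness, following the
% linked example; the original post showed only the two lines in between.
dlYPred = forward(dlnet, dlX);
loss = myLoss(dlX, dlYPred, Y);
gradients = dlgradient(loss, dlnet.Learnables);  % the error is raised here
end

function average_loss = myLoss(dlX, dlYPred, Y)
Ypred = extractdata(dlYPred);
X = extractdata(dlX);
sum_loss = 0;
% loop through all the data in the mini-batch and calculate the average loss
for i = 1:size(dlX, 4)
    % f_calculate_loss() is a self-defined loss function that takes X,
    % Ytarget and Ypred as inputs
    sum_loss = sum_loss + f_calculate_loss(X(:,:,:,i), Y(i,:), Ypred(:,i));
end
% calculate the average loss and convert it back to a dlarray
average_loss = sum_loss / size(dlX, 4);   % 1-by-1 scalar
average_loss = dlarray(average_loss);
end
```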
In my code, `dlX` is a 4-D 3(S)-by-72(S)-by-1(C)-by-64(B) `single` dlarray, `dlYPred` is a 3-by-64 dlarray, and `Y` is a 64-by-3 `double` array, where 64 is the `miniBatchSize`. The calculated loss (non-zero) is a 1-by-1 dlarray.
I've been stuck on this issue for weeks and still can't understand what's going on. I would really appreciate it if anyone could explain what is happening here and how I should fix it. Thank you so much in advance!
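For reference, the same error can be reproduced without a network. In this minimal sketch (the function names `brokenLoss` and `tracedLoss` are hypothetical), wrapping the output of `extractdata` back in `dlarray` gives a value that `dlgradient` should reject, while the same computation done directly on the dlarray differentiates normally.

```matlab
x = dlarray(single([1 2 3]));

% Untraced version: expected to fail inside dlgradient.
failed = false;
try
    dlfeval(@brokenLoss, x);
catch err
    failed = true;
    disp(err.message)
end

% Traced version: the gradient of sum(x.^2) is 2*x.
[lossOK, gradOK] = dlfeval(@tracedLoss, x);

function g = brokenLoss(x)
v = extractdata(x);          % plain single array: the trace ends here
loss = dlarray(sum(v.^2));   % new dlarray, not connected to x
g = dlgradient(loss, x);     % loss is not traced, so this errors
end

function [loss, g] = tracedLoss(x)
loss = sum(x.^2);            % dlarray arithmetic keeps the trace
g = dlgradient(loss, x);
end
```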

### Answers (1)

Hrishikesh Borate on 16 Jul 2021
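Hi,
The problem is caused by calling `extractdata` before the gradient is computed. `extractdata` returns an ordinary numeric array, so everything computed from it is no longer recorded on the automatic-differentiation trace, and wrapping the result in `dlarray()` afterwards only creates a new, untraced dlarray. Compute the loss using only dlarray-supported functions instead.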
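A minimal sketch of a traced `myLoss`, assuming the per-sample loss can be written with dlarray operations. The body of `f_calculate_loss` is not shown in the question, so the loss below (squared error over the 3 outputs, weighted by the mean of that sample's X) is a hypothetical placeholder, and `lossAndGrad` is a hypothetical stand-in for `modelGradients` that differentiates with respect to `dlYPred` so the check runs without a network. `stripdims` removes the data-format labels but keeps the values on the trace.

```matlab
% Random data shaped like the question's mini-batch (values are arbitrary).
dlX = dlarray(rand(3, 72, 1, 64, 'single'), 'SSCB');
dlYPred = dlarray(rand(3, 64, 'single'), 'CB');
Y = rand(64, 3);                               % targets, 64-by-3 double

[loss, g] = dlfeval(@lossAndGrad, dlX, dlYPred, Y);

function [loss, g] = lossAndGrad(dlX, dlYPred, Y)
% Hypothetical stand-in for modelGradients: differentiate w.r.t. dlYPred
% instead of dlnet.Learnables so no network is needed.
loss = myLoss(dlX, dlYPred, Y);
g = dlgradient(loss, dlYPred);
end

function average_loss = myLoss(dlX, dlYPred, Y)
% No extractdata: every step uses dlarray-supported functions.
% HYPOTHETICAL per-sample loss standing in for f_calculate_loss.
Ypred = stripdims(dlYPred);                    % 3-by-64, still traced
X = stripdims(dlX);                            % 3-by-72-by-1-by-64, still traced
T = Y.';                                       % 64-by-3 -> 3-by-64 to match Ypred
sqErr = sum((Ypred - T).^2, 1);                % 1-by-64 per-sample error
w = mean(reshape(X, [], size(X, 4)), 1);       % 1-by-64 per-sample weight
average_loss = mean(w .* sqErr);               % traced 1-by-1 dlarray
end
```

In a real `modelGradients`, the call would be `dlgradient(loss, dlnet.Learnables)` as in the linked example. A `for` loop over the mini-batch would also work, as long as it indexes the dlarrays themselves rather than their extracted data; the vectorized form is just shorter.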

### Categories

Custom Training Loops