Gradients not recorded for a dlnetwork VAE
I have created a VAE using dlnetwork, with an encoder, a decoder and a classifier. The loss function does not record gradients for the update via dlgradient. Can you take a look and find the cause?
% === Setup ===
numEpochs = 5;
miniBatchSize = 32;
learnRate = 1e-3;
latentDim = 32;
hiddenUnits = 64;
inputSize = size(XTrain{1}, 1); % Number of features
sequenceLength = size(XTrain{1}, 2); % Time steps
% === Encoder ===
encoderLayers = [
    sequenceInputLayer(inputSize, 'Name', 'input')
    lstmLayer(64, 'OutputMode', 'last', 'Name', 'lstm_enc')
    fullyConnectedLayer(latentDim, 'Name', 'fc_mu')
    fullyConnectedLayer(latentDim, 'Name', 'fc_logvar')
];
encoderNet = dlnetwork(layerGraph(encoderLayers));
% === Decoder ===
decoderLayers = [
    sequenceInputLayer(latentDim, 'Name', 'latent_input')
    fullyConnectedLayer(64, 'Name', 'fc_latent')
    lstmLayer(64, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
    fullyConnectedLayer(inputSize, 'Name', 'fc_recon')
];
decoderNet = dlnetwork(layerGraph(decoderLayers));
% === Classifier ===
classifierLayers = [
    featureInputLayer(latentDim, 'Name', 'class_input')
    fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
    reluLayer('Name', 'relu_class')
    fullyConnectedLayer(numClasses, 'Name', 'fc_class')
    softmaxLayer('Name', 'softmax_out')
];
classifierNet = dlnetwork(layerGraph(classifierLayers));
function loss = computeTotalLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
    % All operations inside this function are traced
    [mu, logvar] = encodeLatents(dlX, encoderNet);
    z = sampleLatents(mu, logvar); % Sampling with reparameterization
    loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength);
end
function [mu, logvar] = encodeLatents(dlX, encoderNet)
    mu = forward(encoderNet, dlX, 'Outputs', 'fc_mu');
    logvar = forward(encoderNet, dlX, 'Outputs', 'fc_logvar');
end
function z = sampleLatents(mu, logvar)
    eps = dlarray(randn(size(mu), 'like', mu)); % Traced random noise
    z = mu + exp(0.5 * logvar) .* eps;
end
function loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength)
    zRepeated = repmat(z, 1, 1, sequenceLength);
    dlZSeq = dlarray(zRepeated, 'CBT');
    reconOut = forward(decoderNet, dlZSeq, 'Outputs', 'fc_recon');
    classOut = forward(classifierNet, dlarray(z, 'CB'), 'Outputs', 'softmax_out');
    reconLoss = mse(reconOut, dlX);
    classLoss = crossentropy(classOut, dlY);
    klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
    loss = reconLoss + classLoss + klLoss;
end
for epoch = 1:numEpochs
    idx = randperm(numel(XTrain));
    totalEpochLoss = 0;
    numBatches = 0;
    for i = 1:miniBatchSize:numel(XTrain)
        batchIdx = idx(i:min(i+miniBatchSize-1, numel(XTrain)));
        XBatch = XTrain(batchIdx);
        YBatch = YTrain(batchIdx, :);
        % === Format Inputs ===
        XCat = cat(3, XBatch{:});
        XCat = permute(XCat, [1, 3, 2]);
        dlX = dlarray(XCat, 'CBT');
        dlY = dlarray(YBatch', 'CB');
        totalLoss = dlfeval(@computeTotalLoss, dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength);
        gradients = dlgradient(totalLoss, encoderNet.Learnables);
The call to dlgradient fails with:
Error using dlarray/dlgradient (line 105)
Value to differentiate is not traced. It must be a traced real dlarray scalar. Use dlgradient inside a function called by dlfeval to trace the variables.
Accepted Answer
More Answers (1)
Matt J
on 22 Aug 2025
Edited: Matt J
on 22 Aug 2025
As the error message says, the call to dlgradient must occur inside the function called by dlfeval (in this case computeTotalLoss), but your code calls it afterwards, in the training loop.
More confusingly, it appears that computeTotalLoss and computeLoss call each other in a recursive loop. That will create problems as well, I imagine. You are not meant to be calling dlfeval inside the loss function. It is meant to be called externally, in your training loop.
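A minimal sketch of this pattern, on a hypothetical toy regression network with random data (modelLoss, the sizes and the learning rate are illustrative, not taken from the question). The try/catch block reproduces the "not traced" error by calling dlgradient on a loss computed outside dlfeval; the loop then shows the intended structure, where the training loop calls dlfeval and the function it evaluates computes both the loss and its gradients.

```matlab
% Hypothetical toy problem: 4 input features, 1 response, 16 observations.
rng(0);
numFeatures = 4;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(1)];
net = dlnetwork(layerGraph(layers));
X = dlarray(randn(numFeatures, 16, 'single'), 'CB');
T = dlarray(randn(1, 16, 'single'), 'CB');

% Reproduces the error from the question: this loss is computed outside
% dlfeval, so it is not traced and dlgradient cannot differentiate it.
untracedLoss = mse(forward(net, X), T);
try
    gradientsOutside = dlgradient(untracedLoss, net.Learnables); %#ok<NASGU>
catch err
    disp(err.message)
end

% Working pattern: the training loop calls dlfeval, and the loss and
% dlgradient both live in the function that dlfeval evaluates.
learnRate = 1e-2;
averageGrad = [];
averageSqGrad = [];
for iteration = 1:50
    [loss, gradients] = dlfeval(@modelLoss, net, X, T);
    [net, averageGrad, averageSqGrad] = adamupdate(net, gradients, ...
        averageGrad, averageSqGrad, iteration, learnRate);
end
fprintf('Final loss: %.4f\n', extractdata(loss));

function [loss, gradients] = modelLoss(net, X, T)
    % Runs under dlfeval, so operations on the learnables are traced.
    Y = forward(net, X);
    loss = mse(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end
```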
2 Comments
Torsten
on 23 Aug 2025
You have to put "dlgradient" inside "computeTotalLoss". Then you can call "dlfeval" from somewhere outside "computeTotalLoss" and "computeLoss".
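Applied to the code in the question, this suggestion might look like the sketch below. It is self-contained, so several parts are assumptions rather than the asker's setup: random single-precision data in one C-by-B-by-T array stand in for the cell arrays XTrain/YTrain, and all sizes are hypothetical. computeTotalLoss now returns gradients for all three networks, because the decoder and classifier also have learnables to update, and it reads the sequence length from dlX instead of taking sequenceLength. One further change: a layer array is connected in order, so in the posted encoder fc_logvar receives the output of fc_mu rather than of the LSTM; here the two heads are wired in parallel with addLayers and connectLayers.

```matlab
% Hypothetical toy sizes and random data standing in for XTrain/YTrain.
rng(0);
inputSize = 3; numTimeSteps = 20; numClasses = 4; numObs = 64;
latentDim = 8; hiddenUnits = 16;
numEpochs = 5; miniBatchSize = 32; learnRate = 1e-3;
XTrain = randn(inputSize, numObs, numTimeSteps, 'single');   % C x B x T
labels = randi(numClasses, 1, numObs);
YTrain = single(labels == (1:numClasses)');                  % one-hot, K x B

% Encoder: fc_mu and fc_logvar are parallel heads on the LSTM output.
lgraph = layerGraph([
    sequenceInputLayer(inputSize, 'Name', 'input')
    lstmLayer(hiddenUnits, 'OutputMode', 'last', 'Name', 'lstm_enc')
    fullyConnectedLayer(latentDim, 'Name', 'fc_mu')]);
lgraph = addLayers(lgraph, fullyConnectedLayer(latentDim, 'Name', 'fc_logvar'));
lgraph = connectLayers(lgraph, 'lstm_enc', 'fc_logvar');
encoderNet = dlnetwork(lgraph);

decoderNet = dlnetwork(layerGraph([
    sequenceInputLayer(latentDim, 'Name', 'latent_input')
    fullyConnectedLayer(hiddenUnits, 'Name', 'fc_latent')
    lstmLayer(hiddenUnits, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
    fullyConnectedLayer(inputSize, 'Name', 'fc_recon')]));

classifierNet = dlnetwork(layerGraph([
    featureInputLayer(latentDim, 'Name', 'class_input')
    fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
    reluLayer('Name', 'relu_class')
    fullyConnectedLayer(numClasses, 'Name', 'fc_class')
    softmaxLayer('Name', 'softmax_out')]));

% One set of Adam state per network.
[encAvg, encSq, decAvg, decSq, clsAvg, clsSq] = deal([]);
iteration = 0;
for epoch = 1:numEpochs
    idx = randperm(numObs);
    for i = 1:miniBatchSize:numObs
        batchIdx = idx(i:min(i+miniBatchSize-1, numObs));
        dlX = dlarray(XTrain(:, batchIdx, :), 'CBT');
        dlY = dlarray(YTrain(:, batchIdx), 'CB');
        iteration = iteration + 1;
        % dlfeval is called from the training loop, not from the loss.
        [loss, gEnc, gDec, gCls] = dlfeval(@computeTotalLoss, ...
            dlX, dlY, encoderNet, decoderNet, classifierNet);
        [encoderNet, encAvg, encSq] = adamupdate(encoderNet, gEnc, encAvg, encSq, iteration, learnRate);
        [decoderNet, decAvg, decSq] = adamupdate(decoderNet, gDec, decAvg, decSq, iteration, learnRate);
        [classifierNet, clsAvg, clsSq] = adamupdate(classifierNet, gCls, clsAvg, clsSq, iteration, learnRate);
    end
    fprintf('Epoch %d: loss %.4f\n', epoch, extractdata(loss));
end

function [loss, gEnc, gDec, gCls] = computeTotalLoss(dlX, dlY, encoderNet, decoderNet, classifierNet)
    % Run the encoder once and read both heads.
    [mu, logvar] = forward(encoderNet, dlX, 'Outputs', {'fc_mu', 'fc_logvar'});
    % Reparameterization: the noise is a plain constant; mu and logvar are traced.
    epsilon = randn(size(mu), 'like', extractdata(mu));
    z = mu + exp(0.5 * logvar) .* epsilon;                   % format 'CB'
    % Repeat z along time for the sequence decoder (dlX is 'CBT', so dim 3 is T).
    zSeq = dlarray(repmat(stripdims(z), [1 1 size(dlX, 3)]), 'CBT');
    reconOut = forward(decoderNet, zSeq);
    classOut = forward(classifierNet, z);
    reconLoss = mse(reconOut, dlX);
    classLoss = crossentropy(classOut, dlY);
    klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
    loss = reconLoss + classLoss + klLoss;
    % dlgradient is called inside the function that dlfeval evaluates.
    [gEnc, gDec, gCls] = dlgradient(loss, encoderNet.Learnables, ...
        decoderNet.Learnables, classifierNet.Learnables);
end
```

All three networks are updated from the gradients returned by the same dlfeval call, each with its own Adam moment estimates.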
Category: Custom Training Loops