How to use Nadam optimizer in training deep neural networks
Training_Options = trainingOptions('sgdm', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 50, ...
    "InitialLearnRate", 1e-5, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', Resized_Validation_Data, ...
    'ValidationFrequency', 40, ...
    "ExecutionEnvironment", "gpu", ...
    'Plots', 'training-progress', ...
    'Verbose', false);
Answers (2)
Nayan
on 5 Apr 2023
Hi
I assume you want to use the "adam" optimizer in place of "sgdm". You only need to replace "sgdm" with "adam" as the solver name (the first argument of trainingOptions):
options = trainingOptions("adam", ...
    InitialLearnRate=3e-4, ...
    SquaredGradientDecayFactor=0.99, ...
    MaxEpochs=20, ...
    MiniBatchSize=64, ...
    Plots="training-progress")
Amanjit Dulai
on 25 Oct 2024
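You can train with Nadam by defining a custom training loop. The function dlupdate can be used to define custom update rules for training. The rules for Nadam at iteration $t$ are shown below, where $g_t$ is the gradient, $m_t$ is velocity, $n_t$ is squaredGradients, $\theta_t$ are the learnable parameters, $\beta_1$ is gradientDecay, $\beta_2$ is squaredGradientDecay, $\alpha$ is learnRate and $\epsilon$ is epsilon:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t$$
$$n_t = \beta_2 n_{t-1} + (1-\beta_2)\,g_t^2$$
$$\hat{m}_t = \frac{\mu_{t+1}\,m_t}{1-\prod_{i=1}^{t+1}\mu_i} + \frac{(1-\mu_t)\,g_t}{1-\prod_{i=1}^{t}\mu_i}$$
$$\hat{n}_t = \frac{n_t}{1-\beta_2^{\,t}}$$
$$\theta_t = \theta_{t-1} - \frac{\alpha\,\hat{m}_t}{\sqrt{\hat{n}_t}+\epsilon}$$
where the momentum is given by:
$$\mu_t = \beta_1\left(1 - \tfrac{1}{2}\,0.96^{\,t\psi}\right)$$
with $\psi$ = momentumDecay.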
Below is an example of how to train a digit classification network using Nadam in a custom training loop:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
    imageInputLayer([28 28 1], Normalization="none")
    convolution2dLayer(5, 20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn = @preprocessMiniBatch, ...
    MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
    Metrics = "Loss", ...
    Info = "Epoch", ...
    XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat);
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
    X = cat(4,XCell{1:end});
    T = cat(2,TCell{1:end});
    T = onehotencode(T,1);
end
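To check the Nadam update used in the loop above in isolation, here is a minimal toolbox-free sketch that applies the same update to a plain numeric vector instead of net.Learnables. The quadratic objective, the handle gradFcn, the vector target and the larger learnRate of 0.1 are hypothetical choices made only so the toy problem converges quickly; the other hyperparameter names and values mirror the example.

```matlab
% Minimal Nadam sketch on a plain numeric vector (no Deep Learning Toolbox needed).
% Hypothetical toy objective: f(w) = 0.5*sum((w - target).^2), gradient w - target.
target  = [3; -2];
gradFcn = @(w) w - target;

% Hyperparameters named as in the example; learnRate = 0.1 is a hypothetical
% value, larger than 0.001, so that the toy problem converges in few iterations.
learnRate            = 0.1;
gradientDecay        = 0.9;    % beta_1
squaredGradientDecay = 0.99;   % beta_2
momentumDecay        = 0.004;  % psi
epsilon              = 1e-8;
numIterations        = 1000;

w = zeros(2, 1);                    % parameters theta
velocity = zeros(size(w));          % m_t
squaredGradients = zeros(size(w));  % n_t

% Momentum schedule mu_t = beta_1*(1 - 0.5*0.96^(t*psi)) for t = 1..numIterations+1,
% and the running products prod(mu(1:t)), precomputed once with cumprod.
mu = gradientDecay*(1 - 0.5*0.96.^((1:numIterations+1)*momentumDecay));
muProd = cumprod(mu);

for t = 1:numIterations
    g = gradFcn(w);
    % Exponential moving averages of the gradient and the squared gradient
    velocity = gradientDecay*velocity + (1 - gradientDecay)*g;
    squaredGradients = squaredGradientDecay*squaredGradients + (1 - squaredGradientDecay)*g.^2;
    % Nesterov-style bias-corrected first moment and bias-corrected second moment
    velocityHat = mu(t+1)*velocity/(1 - muProd(t+1)) + (1 - mu(t))*g/(1 - muProd(t));
    squaredGradientsHat = squaredGradients/(1 - squaredGradientDecay^t);
    % Parameter update
    w = w - learnRate*velocityHat./(sqrt(squaredGradientsHat) + epsilon);
end

disp(w)  % ends up close to target
```

Precomputing the schedule with cumprod gives the same values as growing the momentums array inside the loop, but avoids recomputing prod over an ever longer vector at every iteration; the same idea could be applied to the dlupdate version.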
1 Comment
Amanjit Dulai
on 25 Oct 2024
Also, if you want to use weight decay only on the weights, you can modify the example as shown below:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
    imageInputLayer([28 28 1], Normalization="none")
    convolution2dLayer(5, 20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
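% Decay only the convolution and fully connected weights, not the biases or
% the batch normalization Offset and Scale parameters
l2Indices = net.Learnables.Parameter == "Weights";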
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn = @preprocessMiniBatch, ...
    MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
    Metrics = "Loss", ...
    Info = "Epoch", ...
    XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Apply weight regularization (L2 term added to the gradients)
        gradients(l2Indices,:) = dlupdate(@(g,w)g + l2RegularizationFactor*w, ...
            gradients(l2Indices,:), net.Learnables(l2Indices,:));
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat);
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
    X = cat(4,XCell{1:end});
    T = cat(2,TCell{1:end});
    T = onehotencode(T,1);
end
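To check which learnable parameters a logical mask selects, here is a small toolbox-free sketch that mimics the Layer and Parameter columns of net.Learnables for the network above. The table is a hypothetical stand-in (the layer names assume MATLAB's auto-generated default names); on a real network, inspect net.Learnables directly. It shows that ~(Parameter == "Bias") also selects the batch normalization Offset and Scale, whereas Parameter == "Weights" selects only the convolution and fully connected weights.

```matlab
% Hypothetical stand-in for the Layer and Parameter columns of net.Learnables
% (layer names assume the auto-generated defaults for the network above).
Layer     = ["conv"; "conv"; "batchnorm"; "batchnorm"; "fc"; "fc"];
Parameter = ["Weights"; "Bias"; "Offset"; "Scale"; "Weights"; "Bias"];
learnables = table(Layer, Parameter);

% Mask that selects everything except biases
notBiasMask = ~(learnables.Parameter == "Bias");
% Mask that selects only the convolution and fully connected weights
weightsMask = learnables.Parameter == "Weights";

disp(learnables(notBiasMask, :))  % also lists batchnorm Offset and Scale
disp(learnables(weightsMask, :))  % lists conv and fc Weights only
```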
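One thing to note is that with adaptive learning rules like Adam and Nadam, it has been found (Loshchilov and Hutter, "Decoupled Weight Decay Regularization") that it is often more effective to apply weight decay directly to the weights instead of adding an L2 term to the gradients, because a term added to the gradients is rescaled by the adaptive denominator like the rest of the gradient. Applying this decoupled weight decay to Nadam gives the NadamW algorithm: each iteration first computes w = w - learnRate*l2RegularizationFactor*w for the selected weights and then applies the usual Nadam update. Below is an example of how to use NadamW.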
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
    imageInputLayer([28 28 1], Normalization="none")
    convolution2dLayer(5, 20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
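% Decay only the convolution and fully connected weights, not the biases or
% the batch normalization Offset and Scale parameters
l2Indices = net.Learnables.Parameter == "Weights";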
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn = @preprocessMiniBatch, ...
    MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
    Metrics = "Loss", ...
    Info = "Epoch", ...
    XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Apply decoupled weight decay (NadamW)
        net.Learnables(l2Indices,:) = dlupdate(@(w)w - learnRate*l2RegularizationFactor*w, ...
            net.Learnables(l2Indices,:));
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat);
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
    X = cat(4,XCell{1:end});
    T = cat(2,TCell{1:end});
    T = onehotencode(T,1);
end
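To see why the placement of the decay term matters for an adaptive method, here is a toy single-step comparison on plain scalars, a sketch with hypothetical values rather than a benchmark. It takes one weight whose loss gradient is zero, so only weight decay acts on it, and applies the first Nadam step (zero-initialized moments) once with the L2 term added to the gradient, as in the second example, and once with the decay applied directly to the weight, as in the NadamW example. The handle nadamStep and the values of w0 and lossGradient are illustrative assumptions.

```matlab
% Toy single-step comparison (hypothetical values): a weight whose loss
% gradient is zero, so the only force acting on it is weight decay.
learnRate              = 0.001;
gradientDecay          = 0.9;
squaredGradientDecay   = 0.99;
momentumDecay          = 0.004;
epsilon                = 1e-8;
l2RegularizationFactor = 0.0001;

w0 = 1;            % hypothetical initial weight
lossGradient = 0;  % hypothetical: zero loss gradient for this weight
t = 1;             % first iteration, so velocity and squaredGradients start at zero

mu = gradientDecay*(1 - 0.5*0.96.^((1:t+1)*momentumDecay));

% Hypothetical helper: amount the first Nadam step (zero-initialized moments)
% subtracts from a weight whose gradient is g.
nadamStep = @(g) learnRate * ...
    (mu(t+1)*(1 - gradientDecay)*g/(1 - prod(mu(1:t+1))) + (1 - mu(t))*g/(1 - prod(mu(1:t)))) ./ ...
    (sqrt((1 - squaredGradientDecay)*g.^2/(1 - squaredGradientDecay^t)) + epsilon);

% Coupled L2: decay term added to the gradient, then normalized by sqrt(n)
wCoupled = w0 - nadamStep(lossGradient + l2RegularizationFactor*w0);

% Decoupled (NadamW-style): shrink the weight first, then take the Nadam step
wDecoupled = w0 - learnRate*l2RegularizationFactor*w0;
wDecoupled = wDecoupled - nadamStep(lossGradient);

fprintf("coupled L2 : %.7f\n", wCoupled);
fprintf("decoupled  : %.7f\n", wDecoupled);
```

With these values the coupled version moves the weight by roughly learnRate (about 1e-3), because the tiny L2 gradient is divided by its own root mean square, while the decoupled version moves it by exactly learnRate*l2RegularizationFactor*w0 = 1e-7. Over many iterations the moment estimates mix the loss gradient and the decay term, so the exact numbers differ, but the L2 term still passes through the adaptive normalization and the decoupled decay does not.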