Custom Training with Multiple GPUs in Experiment Manager
This example shows how to configure multiple parallel workers to collaborate on each trial of a custom training experiment. In this example, parallel workers train on portions of the overall mini-batch in each trial of an image classification experiment. During training, a DataQueue object sends training progress information back to Experiment Manager. If you have a supported GPU, then training happens on the GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox).
As an alternative, you can set up a parallel custom training loop that runs a single trial of this experiment programmatically. For more information, see Train Network in Parallel with Custom Training Loop.
Open Experiment
First, open the example. Experiment Manager loads a project with a preconfigured experiment that you can inspect and run. To open the experiment, in the Experiment Browser pane, double-click the name of the experiment (ParallelCustomLoopExperiment).
Custom training experiments consist of a description, a table of hyperparameters, and a training function. For more information, see Configure Custom Training Experiment.
The Description field contains a textual description of the experiment. For this example, the description is:
Use multiple parallel workers to train an image classification network. Each trial uses a different initial learning rate and momentum.
The Hyperparameters section specifies the hyperparameter values to use for the experiment. When you run the experiment, Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. This example uses two hyperparameters:
- InitialLearnRate sets the initial learning rate used for training. If the learning rate is too low, then training takes a long time. If the learning rate is too high, then training can reach a suboptimal result or diverge. The best learning rate depends on your data as well as the network you are training.
- Momentum specifies the contribution of the gradient step from the previous iteration to the current iteration of stochastic gradient descent with momentum (see the update formula after this list).
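For reference, these hyperparameters enter the stochastic gradient descent with momentum update of the learnable parameters. A standard formulation of that update (the symbols below are not used elsewhere in this example) is

$$\theta_{\ell+1} = \theta_\ell - \alpha \nabla E(\theta_\ell) + \gamma \left(\theta_\ell - \theta_{\ell-1}\right),$$

where $\theta_\ell$ denotes the learnable parameters at iteration $\ell$, $\alpha$ is the learning rate (InitialLearnRate, with time-based decay in this example), $\nabla E(\theta_\ell)$ is the gradient of the loss, and $\gamma$ is the momentum.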
The Training Function specifies the training data, network architecture, training options, and training procedure used by the experiment. The input to the training function is a structure with fields from the hyperparameter table and an experiments.Monitor object that you can use to track the progress of the training, record values of the metrics used by the training, and produce training plots. The training function returns a structure that contains the trained network, the training loss, and the validation accuracy. Experiment Manager saves this output, so you can export it to the MATLAB workspace when the training is complete. The training function has six sections.
Initialize Output sets the initial value of the network, training loss, and validation accuracy to empty arrays to indicate that the training has not started.
output.network = [];
output.loss = [];
output.accuracy = [];
Load Training and Test Data defines the training and test data for the experiment as imageDatastore objects. The experiment uses the Digits data set, which consists of 10,000 28-by-28 pixel grayscale images of digits from 0 to 9, categorized by the digit they represent. For more information on this data set, see Image Data Sets.
monitor.Status = "Loading Data";
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");
classes = categories(imdsTrain.Labels); numClasses = numel(classes);
XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;
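If you want to inspect the loaded data before training, one optional check (not part of the experiment code) is to count the labels in the training split and confirm the size of the preprocessed test array:

countEachLabel(imdsTrain)   % table of label counts for the 90% training split
size(XTest)                 % 28-by-28-by-1-by-numel(YTest) single array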
Define Network Architecture defines the architecture for the image classification network. This network architecture includes batch normalization layers that track the mean and variance statistics of the data set. When training in parallel, to ensure the network state reflects the whole mini-batch, combine the statistics from all of the workers at the end of each iteration step. Otherwise, the network state can diverge across the workers. If you are training stateful recurrent neural networks (RNNs), for example, using sequence data that has been split into smaller sequences to train networks containing LSTM or GRU layers, you must also manage the state between the workers. To train the network with a custom training loop and enable automatic differentiation, the training function converts the layer graph to a dlnetwork object.
monitor.Status = "Creating Network";
layers = [
imageInputLayer([28 28 1],Normalization="none")
convolution2dLayer(5,20)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,20,Padding=1)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,20,Padding=1)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)];
lgraph = layerGraph(layers);
net = dlnetwork(lgraph);
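To inspect the architecture before running the experiment, you can optionally open the layer graph in the network analyzer or list the learnable parameters of the initialized dlnetwork; these inspection steps are not part of the training function:

analyzeNetwork(lgraph)   % interactive summary of layers and learnable sizes
net.Learnables           % table of the learnable parameters of the dlnetwork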
Set Up Parallel Environment determines whether GPUs are available for MATLAB to use. If GPUs are available, then the experiment trains on the GPUs and, if no parallel pool exists, creates one with as many workers as available GPUs. If no GPUs are available, then the experiment trains on the CPUs and, if no parallel pool exists, creates one with the default number of workers.
monitor.Status = "Starting Parallel Pool";
pool = gcp("nocreate");
if canUseGPU
    executionEnvironment = "gpu";
    if isempty(pool)
        numberOfGPUs = gpuDeviceCount("available");
        pool = parpool(numberOfGPUs);
    end
else
    executionEnvironment = "cpu";
    if isempty(pool)
        pool = parpool;
    end
end
N = pool.NumWorkers;
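To confirm which resource each worker uses, you can optionally run a short diagnostic on the pool; this sketch is not part of the training function:

spmd
    if executionEnvironment == "gpu"
        dev = gpuDevice;   % GPU assigned to this worker
        fprintf("Worker %d uses GPU: %s\n",labindex,dev.Name);
    else
        fprintf("Worker %d trains on the CPU\n",labindex);
    end
end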
Specify Training Options defines the training options used by the experiment. In this example, Experiment Manager trains the network with a mini-batch size of 128 for 20 epochs using the initial learning rate and momentum defined in the hyperparameter table. If you are training on a GPU, the mini-batch size scales up linearly with the number of GPUs to keep the workload on each GPU constant. For more information, see Deep Learning with MATLAB on Multiple GPUs.
numEpochs = 20;
miniBatchSize = 128;
velocity = [];
initialLearnRate = params.InitialLearnRate;
momentum = params.Momentum;
decay = 0.01;

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N;
end

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)];
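As a worked example of this split (hypothetical values, not taken from the experiment), a scaled mini-batch of 512 observations distributed across N = 3 workers assigns the two leftover observations to the first two workers:

N = 3;
miniBatchSize = 512;
workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));   % [170 170 170]
remainder = miniBatchSize - sum(workerMiniBatchSize);          % 2
workerMiniBatchSize = workerMiniBatchSize + ...
    [ones(1,remainder) zeros(1,N-remainder)]                   % [171 171 170]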
Train Model defines the parallel custom training loop used by the experiment. To execute the code simultaneously on all the workers, the training function uses an spmd block, which cannot contain break, continue, or return statements. As a result, you cannot interrupt a trial of the experiment while training is in progress. If you press Stop, Experiment Manager runs the current trial to completion before stopping the experiment. For more information on the parallel custom training loop, see Appendix 1 at the end of this example.
monitor.Metrics = ["TrainingLoss" "ValidationAccuracy"];
monitor.XLabel = "Iteration";
monitor.Status = "Training";

Q = parallel.pool.DataQueue;
updateFcn = @(x) updateTrainingProgress(x,monitor);
afterEach(Q,updateFcn);
spmd
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);

    workerVelocity = velocity;

    iteration = 0;
    lossArray = [];
    accuracyArray = [];

    for epoch = 1:numEpochs
        reset(workerImds);
        workerImds = shuffle(workerImds);

        if ~monitor.Stop
            while gop(@and,hasdata(workerImds))
                iteration = iteration + 1;

                [workerXBatch,workerTBatch] = read(workerImds);
                workerXBatch = cat(4,workerXBatch{:});
                workerNumObservations = numel(workerTBatch.Label);

                workerXBatch = single(workerXBatch) ./ 255;

                workerY = zeros(numClasses,workerNumObservations,"single");
                for c = 1:numClasses
                    workerY(c,workerTBatch.Label==classes(c)) = 1;
                end

                workerX = dlarray(workerXBatch,"SSCB");

                if executionEnvironment == "gpu"
                    workerX = gpuArray(workerX);
                end

                [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerY);

                workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
                loss = gplus(workerNormalizationFactor*extractdata(workerLoss));

                net.State = aggregateState(workerState,workerNormalizationFactor);

                workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});

                learnRate = initialLearnRate/(1 + decay*iteration);

                [net.Learnables,workerVelocity] = sgdmupdate(net.Learnables,workerGradients,workerVelocity,learnRate,momentum);
            end

            if labindex == 1
                YPredScores = predict(net,dlarray(XTest,"SSCB"));
                [~,idx] = max(YPredScores,[],1);
                YPred = classes(idx);
                accuracy = mean(YPred==YTest);

                lossArray = [lossArray; iteration, loss];
                accuracyArray = [accuracyArray; iteration, accuracy];

                data = [numEpochs epoch iteration loss accuracy];
                send(Q,gather(data));
            end
        end
    end
end
To inspect the training function, under Training Function, click Edit. The training function opens in MATLAB® Editor. In addition, the code for the training function appears in Appendix 1 at the end of this example.
Run Experiment
When you run the experiment, Experiment Manager trains the network defined by the training function multiple times. Each trial uses a different combination of hyperparameter values.
Because this experiment uses the parallel pool for this MATLAB session, you cannot train multiple trials at the same time. On the Experiment Manager toolstrip, under Mode, select Sequential and click Run. Alternatively, to offload the experiment as a batch job, set Mode to Batch Sequential, specify your Cluster and Pool Size, and click Run. For more information, see Offload Experiments as Batch Jobs to Cluster.
A table of results displays the training loss and validation accuracy for each trial.
To display the training plot and track the progress of each trial while the experiment is running, under Review Results, click Training Plot.
Note that the training function for this experiment uses an spmd statement, which cannot contain break, continue, or return statements. As a result, you cannot interrupt a trial of the experiment while training is in progress. If you click Stop, Experiment Manager runs the current trial to completion before stopping the experiment.
Evaluate Results
To find the best result for your experiment, sort the table of results by validation accuracy.
Point to the ValidationAccuracy column.
Click the triangle icon.
Select Sort in Descending Order.
The trial with the highest validation accuracy appears at the top of the results table.
To test the best trial in your experiment, plot a confusion matrix.
In the results table, select the trial with the highest validation accuracy.
On the Experiment Manager toolstrip, click Export > Training Output.
In the dialog window, enter the name of a workspace variable for the exported training output. The default name is trainingOutput.
Create a confusion matrix by calling the drawConfusionMatrix function, which is listed in Appendix 3 at the end of this example. As the input to the function, use the exported training output and the fraction of the Digits data set to use as a test set. For instance, in the MATLAB Command Window, enter:
drawConfusionMatrix(trainingOutput,0.5)
The function creates a confusion matrix using half of the images in the data set.
To record observations about the results of your experiment, add an annotation.
In the results table, right-click the ValidationAccuracy cell of the best trial.
Select Add Annotation.
In the Annotations pane, enter your observations in the text box.
For more information, see Sort, Filter, and Annotate Experiment Results.
Close Experiment
In the Experiment Browser pane, right-click the name of the project and select Close Project. Experiment Manager closes all of the experiments and results contained in the project.
Appendix 1: Training Function
This function configures the training data, network architecture, and training options for the experiment. To execute the code simultaneously on all the workers, the function uses an spmd block. Within the spmd block, labindex gives the index of the worker currently executing the code. Before training, the function partitions the datastore for each worker by using the partition function, and sets ReadSize to the mini-batch size of the worker. For each epoch, the function resets and shuffles the datastore. For each iteration in the epoch, the function:
- Reads a mini-batch from the datastore and processes the data for training.
- Computes the loss and the gradients of the network on each worker by calling dlfeval on the modelLoss function.
- Obtains the overall loss by using the gplus function to add together the normalized cross-entropy losses from all of the workers.
- Aggregates and updates the gradients of all of the workers by using the dlupdate function with the aggregateGradients function.
- Aggregates the state of the network on all of the workers by using the aggregateState function.
- Updates the network learnable parameters with the sgdmupdate function.
At the end of each epoch, the function uses only one worker to send the training progress information back to the client.
Input
- params is a structure with fields from the Experiment Manager hyperparameter table.
- monitor is an experiments.Monitor object that you can use to track the progress of the training, update information fields in the results table, record values of the metrics used by the training, and produce training plots.
Output
output is a structure that contains the trained dlnetwork object, the training loss array, and the validation accuracy array. Experiment Manager saves this output, so you can export it to the MATLAB workspace when the training is complete.
function output = ParallelCustomLoopExperiment_training(params,monitor)

output.network = [];
output.loss = [];
output.accuracy = [];

monitor.Status = "Loading Data";

digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

monitor.Status = "Creating Network";

layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)];

lgraph = layerGraph(layers);
net = dlnetwork(lgraph);

monitor.Status = "Starting Parallel Pool";

pool = gcp("nocreate");

if canUseGPU
    executionEnvironment = "gpu";
    if isempty(pool)
        numberOfGPUs = gpuDeviceCount("available");
        pool = parpool(numberOfGPUs);
    end
else
    executionEnvironment = "cpu";
    if isempty(pool)
        pool = parpool;
    end
end

N = pool.NumWorkers;

numEpochs = 20;
miniBatchSize = 128;
velocity = [];
initialLearnRate = params.InitialLearnRate;
momentum = params.Momentum;
decay = 0.01;

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N;
end

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)];

monitor.Metrics = ["TrainingLoss" "ValidationAccuracy"];
monitor.XLabel = "Iteration";
monitor.Status = "Training";

Q = parallel.pool.DataQueue;
updateFcn = @(x) updateTrainingProgress(x,monitor);
afterEach(Q,updateFcn);

spmd
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);

    workerVelocity = velocity;

    iteration = 0;
    lossArray = [];
    accuracyArray = [];

    for epoch = 1:numEpochs
        reset(workerImds);
        workerImds = shuffle(workerImds);

        if ~monitor.Stop
            while gop(@and,hasdata(workerImds))
                iteration = iteration + 1;

                [workerXBatch,workerTBatch] = read(workerImds);
                workerXBatch = cat(4,workerXBatch{:});
                workerNumObservations = numel(workerTBatch.Label);

                workerXBatch = single(workerXBatch) ./ 255;

                workerY = zeros(numClasses,workerNumObservations,"single");
                for c = 1:numClasses
                    workerY(c,workerTBatch.Label==classes(c)) = 1;
                end

                workerX = dlarray(workerXBatch,"SSCB");

                if executionEnvironment == "gpu"
                    workerX = gpuArray(workerX);
                end

                [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerY);

                workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
                loss = gplus(workerNormalizationFactor*extractdata(workerLoss));

                net.State = aggregateState(workerState,workerNormalizationFactor);

                workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});

                learnRate = initialLearnRate/(1 + decay*iteration);

                [net.Learnables,workerVelocity] = sgdmupdate(net.Learnables,workerGradients,workerVelocity,learnRate,momentum);
            end

            if labindex == 1
                YPredScores = predict(net,dlarray(XTest,"SSCB"));
                [~,idx] = max(YPredScores,[],1);
                YPred = classes(idx);
                accuracy = mean(YPred==YTest);

                lossArray = [lossArray; iteration, loss];
                accuracyArray = [accuracyArray; iteration, accuracy];

                data = [numEpochs epoch iteration loss accuracy];
                send(Q,gather(data));
            end
        end
    end
end

output.network = net{1};
output.loss = lossArray{1};
output.accuracy = accuracyArray{1};

delete(gcp("nocreate"));

end
Appendix 2: Custom Training Helper Functions
This function computes the loss and the gradients of the loss with respect to the learnable parameters of the network. The function computes the network outputs for a mini-batch X with forward and softmax and calculates the loss, given the true outputs, using cross-entropy. When you call this function with dlfeval, automatic differentiation is enabled, and dlgradient can compute the gradients of the loss with respect to the learnable parameters automatically.
function [loss,gradients,state] = modelLoss(net,X,Y)
[YPred,state] = forward(net,X);
YPred = softmax(YPred);
loss = crossentropy(YPred,Y);
gradients = dlgradient(loss,net.Learnables);
end
This function displays training progress information and updates metric values that come from the workers. The DataQueue object in this example calls this function every time a worker sends data.
function updateTrainingProgress(data,monitor)
monitor.Progress = (data(2)/data(1))*100;
recordMetrics(monitor,data(3), ...
    TrainingLoss=data(4), ...
    ValidationAccuracy=data(5));
end
This function aggregates the gradients on all workers by adding them together. gplus adds together and replicates all the gradients on the workers. Before adding them together, the function normalizes them by multiplying them by a factor that represents the proportion of the overall mini-batch that the worker is working on.
function gradients = aggregateGradients(gradients,factor)
gradients = gplus(factor*gradients);
end
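Expressed with the same symbols as the state-aggregation formula below, the aggregated gradient is the weighted sum of the per-worker gradients:

$$\nabla L = \sum_{j=1}^{N} \frac{m_j}{M}\,\nabla L_j,$$

where $\nabla L_j$ is the gradient computed on the $j$th worker and $m_j/M$ is the normalization factor that the training loop passes to aggregateGradients.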
This function aggregates the network state on all workers. The network state contains the trained batch normalization statistics of the data set. Since each worker sees only a portion of the mini-batch, aggregate the network state so that the statistics are representative of the statistics across all the data. For each mini-batch, the combined mean is calculated as a weighted average of the mean across the workers for each iteration. The combined variance is calculated according to the formula

$$s_c^2 = \frac{1}{M}\sum_{j=1}^{N} m_j \left( s_j^2 + \left(\bar{x}_j - \bar{x}_c\right)^2 \right),$$

where $N$ is the total number of workers, $M$ is the total number of observations in a mini-batch, $m_j$ is the number of observations processed on the $j$th worker, $\bar{x}_j$ and $s_j^2$ are the mean and variance statistics calculated on that worker, and $\bar{x}_c$ is the combined mean across all of the workers.
function state = aggregateState(state,factor)

numrows = size(state,1);

for j = 1:numrows
    isBatchNormalizationState = state.Parameter(j) =="TrainedMean" ...
        && state.Parameter(j+1) =="TrainedVariance" ...
        && state.Layer(j) == state.Layer(j+1);

    if isBatchNormalizationState
        meanVal = state.Value{j};
        varVal = state.Value{j+1};

        combinedMean = gplus(factor*meanVal);
        combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);

        state.Value(j) = {combinedMean};
        state.Value(j+1) = {gplus(combinedVarTerm)};
    end
end

end
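As a quick way to convince yourself that this weighted combination reproduces the statistics of the pooled data, you can run a small local check (hypothetical values, unrelated to the experiment code):

x1 = randn(1,60);             % portion of a mini-batch seen by worker 1
x2 = randn(1,40) + 1;         % portion of a mini-batch seen by worker 2
M = numel(x1) + numel(x2);
f = [numel(x1) numel(x2)]/M;  % normalization factors m_j/M
combinedMean = f(1)*mean(x1) + f(2)*mean(x2);
combinedVar = f(1)*(var(x1,1) + (mean(x1)-combinedMean)^2) + ...
    f(2)*(var(x2,1) + (mean(x2)-combinedMean)^2);
max(abs([combinedMean combinedVar] - [mean([x1 x2]) var([x1 x2],1)]))   % close to zero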
Appendix 3: Create Confusion Matrix
This function takes as input the training output exported from Experiment Manager and the fraction of the Digits data set to use as a test set, and creates a confusion matrix chart.
function drawConfusionMatrix(trainingOutput,testSize)

dataFolder = fullfile(toolboxdir("nnet"), ...
    "nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[~,imdsTest] = splitEachLabel(imds,testSize,"randomized");

classes = categories(imdsTest.Labels);

trainedNet = trainingOutput.network;

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
trueLabels = imdsTest.Labels;

YPredScores = predict(trainedNet,dlarray(XTest,"SSCB"));
[~,idx] = max(YPredScores,[],1);
predictedLabels = classes(idx);

figure
confusionchart(trueLabels,categorical(predictedLabels), ...
    ColumnSummary="column-normalized", ...
    RowSummary="row-normalized", ...
    Title="Confusion Matrix for Digits Data Set");
cm = gcf;
cm.Position(3) = cm.Position(3)*1.5;

end
Related Topics
- Train Network in Parallel with Custom Training Loop
- Scale Up Deep Learning in Parallel, on GPUs, and in the Cloud
- Deep Learning with MATLAB on Multiple GPUs
- Use Experiment Manager to Train Networks in Parallel
- Offload Experiments as Batch Jobs to Cluster
- Use Parallel Computing Toolbox with Cloud Center Cluster in MATLAB Online (Parallel Computing Toolbox)