
This example shows how to set up a custom training loop to train a network in parallel. In this example, parallel workers train on portions of the overall mini-batch. If you have a GPU, then training happens on the GPU. During training, a `DataQueue` object sends training progress information back to the MATLAB client.

Load the digit data set and create an image datastore for the data set. Randomly split the datastore into training and test datastores, keeping 90% of the images of each label for training.

    digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
        'nndatasets','DigitDataset');
    imds = imageDatastore(digitDatasetPath, ...
        'IncludeSubfolders',true, ...
        'LabelSource','foldernames');
    [imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

Determine the different classes in the training set.

    classes = categories(imdsTrain.Labels);
    numClasses = numel(classes);

Define your network architecture and make it into a layer graph by using the `layerGraph` function. This network architecture includes batch normalization layers, which track the mean and variance statistics of the data set. When training in parallel, combine the statistics from all of the workers at the end of each iteration step to ensure that the network state reflects the whole mini-batch. Otherwise, the network state can diverge across the workers. If you are training stateful recurrent neural networks (RNNs), for example, using sequence data that has been split into smaller sequences to train networks containing LSTM or GRU layers, you must also manage the state between the workers.

    layers = [
        imageInputLayer([28 28 1],'Name','input','Normalization','none')
        convolution2dLayer(5,20,'Name','conv1')
        batchNormalizationLayer('Name','bn1')
        reluLayer('Name','relu1')
        convolution2dLayer(3,20,'Padding',1,'Name','conv2')
        batchNormalizationLayer('Name','bn2')
        reluLayer('Name','relu2')
        convolution2dLayer(3,20,'Padding',1,'Name','conv3')
        batchNormalizationLayer('Name','bn3')
        reluLayer('Name','relu3')
        fullyConnectedLayer(numClasses,'Name','fc')];
    lgraph = layerGraph(layers);

Create a `dlnetwork` object from the layer graph. `dlnetwork` objects allow for training with custom loops.

    dlnet = dlnetwork(lgraph)

    dlnet = 
      dlnetwork with properties:

             Layers: [11×1 nnet.cnn.layer.Layer]
        Connections: [10×2 table]
         Learnables: [14×3 table]
              State: [6×3 table]
         InputNames: {'input'}
        OutputNames: {'fc'}
        Initialized: 1

Determine whether GPUs are available for MATLAB to use with the `canUseGPU` function.

If there are GPUs available, then train on the GPUs. Create a parallel pool with as many workers as GPUs.

If there are no GPUs available, then train on the CPUs. Create a parallel pool with the default number of workers.

    if canUseGPU
        executionEnvironment = "gpu";
        numberOfGPUs = gpuDeviceCount("available");
        pool = parpool(numberOfGPUs);
    else
        executionEnvironment = "cpu";
        pool = parpool;
    end

Get the number of workers in the parallel pool. Later in this example, you divide the workload according to this number.

    N = pool.NumWorkers;

Specify the training options.

    numEpochs = 20;
    miniBatchSize = 128;
    velocity = [];

For GPU training, a recommended practice is to scale up the mini-batch size linearly with the number of GPUs, in order to keep the workload on each GPU constant. For more related advice, see Training with Multiple GPUs.

    if executionEnvironment == "gpu"
        miniBatchSize = miniBatchSize .* N
    end

    miniBatchSize = 512

Calculate the mini-batch size for each worker by dividing the overall mini-batch size evenly among the workers. If the division leaves a remainder, give one extra observation to each of the first workers until the remainder is used up.

    workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
    remainder = miniBatchSize - sum(workerMiniBatchSize);
    workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]

    workerMiniBatchSize = 1×4

       128   128   128   128
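With 512 observations on 4 workers the split is even, so the remainder step has no effect. The following sketch runs the same arithmetic for a hypothetical mini-batch of 130 observations on 4 workers, where it does. The names `totalSize`, `numWorkers`, `shares`, `leftover`, and `factors` are illustrative and deliberately differ from the example's variables, so running the snippet does not overwrite `miniBatchSize`, `N`, or `remainder`.

```matlab
% Hypothetical sizes chosen so that the division is uneven.
totalSize = 130;
numWorkers = 4;

% Every worker first gets the floor of the even share: [32 32 32 32].
shares = floor(totalSize ./ repmat(numWorkers,1,numWorkers));

% The leftover observations (here 2) go to the first workers, one each.
leftover = totalSize - sum(shares);
shares = shares + [ones(1,leftover) zeros(1,numWorkers-leftover)];

% Each worker's fraction of the overall mini-batch; the fractions sum to 1.
factors = shares ./ totalSize;

disp(shares)   % the first two workers get 33, the others 32
```

The fractions in `factors` play the same role as `workerNormalizationFactor` in the training loop: they weight each worker's contribution by the share of the mini-batch it processed.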

Initialize the training progress plot.

    % Set up the training plot figure.
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on

To send data back from the workers during training, create a `DataQueue` object. Use `afterEach` to set up a function, `displayTrainingProgress`, that runs each time a worker sends data. `displayTrainingProgress` is a supporting function, defined at the end of this example, that displays the training progress information that comes from the workers.

    Q = parallel.pool.DataQueue;
    displayFcn = @(x) displayTrainingProgress(x,lineLossTrain);
    afterEach(Q,displayFcn);

Train the model using a custom parallel training loop, as detailed in the following steps. To execute the code simultaneously on all the workers, use an `spmd` block. Within the `spmd` block, `labindex` gives the index of the worker currently executing the code.

Before training, partition the datastore for each worker by using the `partition` function, and set `ReadSize` to the mini-batch size of the worker.

For each epoch, reset and shuffle the datastore with the `reset` and `shuffle` functions. For each iteration in the epoch:

- Ensure that all workers have data available before processing it in parallel by performing a global `and` operation (`gop`) on the result of the `hasdata` function.
- Read a mini-batch from the datastore by using the `read` function, and concatenate the retrieved images into a four-dimensional array of images. Normalize the images so that the pixels take values between `0` and `1`.
- Convert the labels to a matrix of dummy variables with one row per class and one column per observation. Each column contains `1` in the row of that observation's label and `0` elsewhere.
- Convert the mini-batch of data to a `dlarray` object with the underlying type single and specify the dimension labels `'SSCB'` (spatial, spatial, channel, batch). For GPU training, convert the data to `gpuArray`.
- Compute the gradients and the loss of the network on each worker by calling `dlfeval` on the `modelGradients` function. The `dlfeval` function evaluates the helper function `modelGradients` with automatic differentiation enabled, so `modelGradients` can compute the gradients of the loss automatically. `modelGradients` is defined at the end of the example and returns the loss and gradients given a network, a mini-batch of data, and the true labels.
- To obtain the overall loss, aggregate the losses on all workers. This example uses cross entropy for the loss function, and the aggregated loss is the sum of the normalized worker losses. Before aggregating, normalize each loss by multiplying it by the proportion of the overall mini-batch that the worker is working on. Use `gplus` to add all losses together and replicate the result across workers.
- To aggregate and update the gradients of all workers, use the `dlupdate` function with the `aggregateGradients` function. `aggregateGradients` is a supporting function defined at the end of this example. This function normalizes each worker's gradients by the proportion of the overall mini-batch that the worker is working on, and then uses `gplus` to add the gradients together and replicate the result across workers.
- Aggregate the state of the network on all workers using the `aggregateState` function. `aggregateState` is a supporting function defined at the end of this example. The batch normalization layers in the network track the mean and variance of the data. Because the complete mini-batch is spread across multiple workers, aggregate the network state after each iteration to compute the mean and variance of the whole mini-batch.
- After computing the final gradients, update the network learnable parameters with the `sgdmupdate` function.

At the end of each epoch, send training progress information back to the client by using the `send` function with the `DataQueue`. Use only one worker to send data, because all workers have the same loss information. To ensure that the data is on the CPU, so that a client machine without a GPU can access it, use `gather` on the data before sending it.

    start = tic;
    spmd
        % Partition datastore.
        workerImds = partition(imdsTrain,N,labindex);
        workerImds.ReadSize = workerMiniBatchSize(labindex);

        workerVelocity = velocity;

        iteration = 0;

        for epoch = 1:numEpochs
            % Reset and shuffle the datastore.
            reset(workerImds);
            workerImds = shuffle(workerImds);

            % Loop over mini-batches.
            while gop(@and,hasdata(workerImds))
                iteration = iteration + 1;

                % Read a mini-batch of data.
                [workerXBatch,workerTBatch] = read(workerImds);
                workerXBatch = cat(4,workerXBatch{:});
                workerNumObservations = numel(workerTBatch.Label);

                % Normalize the images.
                workerXBatch = single(workerXBatch) ./ 255;

                % Convert the labels to dummy variables.
                workerY = zeros(numClasses,workerNumObservations,'single');
                for c = 1:numClasses
                    workerY(c,workerTBatch.Label==classes(c)) = 1;
                end

                % Convert the mini-batch of data to dlarray.
                dlworkerX = dlarray(workerXBatch,'SSCB');

                % If training on GPU, then convert data to gpuArray.
                if executionEnvironment == "gpu"
                    dlworkerX = gpuArray(dlworkerX);
                end

                % Evaluate the model gradients and loss on the worker.
                [workerGradients,dlworkerLoss,workerState] = dlfeval(@modelGradients,dlnet,dlworkerX,workerY);

                % Aggregate the losses on all workers.
                workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
                loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));

                % Aggregate the network state on all workers.
                dlnet.State = aggregateState(workerState,workerNormalizationFactor);

                % Aggregate the gradients on all workers.
                workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});

                % Update the network parameters using the SGDM optimizer.
                [dlnet.Learnables,workerVelocity] = sgdmupdate(dlnet.Learnables,workerGradients,workerVelocity);
            end

            % Display training progress information.
            if labindex == 1
                data = [epoch loss iteration toc(start)];
                send(Q,gather(data));
            end
        end
    end
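The normalization factors in the loop are what make the aggregated loss equal to the mean loss over the whole mini-batch. The following serial sketch checks this without a parallel pool: a running sum stands in for `gplus`, random probabilities stand in for network outputs, and the per-worker sizes `[33 33 32 32]` are hypothetical. It also reuses the dummy-variable loop from the training code to build the targets. The sketch assumes that each worker's loss is averaged over that worker's own observations (the default normalization of `crossentropy`); all `demo*` names and the `meanCE` helper are illustrative.

```matlab
rng(0)                                          % reproducible illustration
demoShares = [33 33 32 32];                     % hypothetical per-worker sizes
demoTotal = sum(demoShares);

% Hypothetical digit labels, with all ten digits in the category list.
demoLabels = categorical(randi([0 9],demoTotal,1),0:9);
demoClasses = categories(demoLabels);

% Dummy variables built with the same loop as the training code
% (double rather than single, so the check is not limited by single precision).
demoT = zeros(numel(demoClasses),demoTotal);
for c = 1:numel(demoClasses)
    demoT(c,demoLabels==demoClasses(c)) = 1;
end

% Random class probabilities standing in for softmax outputs.
scores = randn(numel(demoClasses),demoTotal);
demoP = exp(scores) ./ sum(exp(scores),1);

% Hypothetical helper: cross entropy averaged over the observations passed in.
meanCE = @(P,T) -sum(T.*log(P),'all') / size(P,2);
fullBatchLoss = meanCE(demoP,demoT);

% Serial stand-in for the workers: each averages over its own columns and is
% scaled by its share of the mini-batch; the running sum plays the role of gplus.
lastIdx = cumsum(demoShares);
firstIdx = lastIdx - demoShares + 1;
aggregatedLoss = 0;
for w = 1:numel(demoShares)
    cols = firstIdx(w):lastIdx(w);
    weight = demoShares(w) / demoTotal;         % like workerNormalizationFactor
    aggregatedLoss = aggregatedLoss + weight*meanCE(demoP(:,cols),demoT(:,cols));
end
fprintf('full batch: %.6f   aggregated: %.6f\n',fullBatchLoss,aggregatedLoss)
```

The two printed values agree up to floating-point rounding. Because the gradient of a weighted sum is the weighted sum of the gradients, the same weighting lets `aggregateGradients` reproduce the full-batch gradient, provided every worker starts the iteration with identical learnable parameters.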

After you train the network, you can test its accuracy.

Load the test images into memory by using `readall` on the test datastore, concatenate them, and normalize them.

    XTest = readall(imdsTest);
    XTest = cat(4,XTest{:});
    XTest = single(XTest) ./ 255;
    YTest = imdsTest.Labels;

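After training is complete, all workers hold the same trained network. Because `dlnet` is modified inside the `spmd` block, on the client it is a `Composite` object with one value per worker. Retrieve the network from any worker by indexing with braces.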

    dlnetFinal = dlnet{1};

To classify images using a `dlnetwork` object, use the `predict` function on a `dlarray`.

    dlYPredScores = predict(dlnetFinal,dlarray(XTest,'SSCB'));

From the predicted scores, find the class with the highest score with the `max` function. Before you do that, extract the data from the `dlarray` with the `extractdata` function.

    [~,idx] = max(extractdata(dlYPredScores),[],1);
    YPred = classes(idx);

To obtain the classification accuracy of the model, compare the predictions on the test set against the true labels.

    accuracy = mean(YPred==YTest)

    accuracy = 0.9910
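The following sketch isolates the prediction and accuracy steps on a hypothetical 3-class score matrix laid out like the extracted `dlYPredScores` (classes in rows, observations in columns). Plain numeric arrays replace `dlarray` here, so no `extractdata` call is needed; the `demo*` names and the class names are illustrative.

```matlab
% Hypothetical class names and scores: 3 classes (rows) x 4 observations (columns).
demoClasses = {'cat';'dog';'fish'};
demoScores = [0.7 0.1 0.2 0.3
              0.2 0.8 0.1 0.3
              0.1 0.1 0.7 0.4];
demoTrue = categorical({'cat';'dog';'dog';'fish'});

% Row index of the largest score in each column.
[~,demoIdx] = max(demoScores,[],1);

% Indexing the column cell array with a row of indices returns a column,
% which lines up element by element with the column of true labels.
demoPred = demoClasses(demoIdx);

% Fraction of observations whose predicted label matches the true label.
demoAccuracy = mean(demoPred == demoTrue)
```

Here the third observation is misclassified as `'fish'`, so three of the four predictions are correct.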

Define a function, `modelGradients`, to compute the gradients of the loss with respect to the learnable parameters of the network. This function computes the network outputs for a mini-batch `dlX` with `forward` and `softmax` and calculates the loss, given the true outputs `dlY`, using cross entropy. It also returns the updated network `state` from `forward`, which the training loop aggregates across workers. When you call this function with `dlfeval`, automatic differentiation is enabled, and `dlgradient` can compute the gradients of the loss with respect to the learnables automatically.

    function [dlgradients,dlloss,state] = modelGradients(dlnet,dlX,dlY)
        [dlYPred,state] = forward(dlnet,dlX);
        dlYPred = softmax(dlYPred);
        dlloss = crossentropy(dlYPred,dlY);
        dlgradients = dlgradient(dlloss,dlnet.Learnables);
    end
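To make concrete what `softmax` followed by `crossentropy` computes, the following sketch evaluates both by hand for two hypothetical observations using plain numeric arrays. It assumes `crossentropy` uses its default normalization, which divides the summed loss by the number of observations. Subtracting the column maximum inside the softmax is a common numerical-stability step and is not necessarily how the built-in function is implemented; the `demo*` names are illustrative.

```matlab
% Hypothetical network outputs: 3 classes (rows) x 2 observations (columns).
demoLogits = [2 0
              0 0
              0 0];
% One-hot targets: observation 1 is class 1, observation 2 is class 2.
demoTargets = [1 0
               0 1
               0 0];

% Softmax over the class dimension; subtracting the column maximum first
% avoids overflow in exp without changing the result.
shifted = demoLogits - max(demoLogits,[],1);
demoProbs = exp(shifted) ./ sum(exp(shifted),1);

% Cross entropy: sum over classes, then average over observations.
numObs = size(demoProbs,2);
demoLoss = -sum(demoTargets .* log(demoProbs),'all') / numObs
```

The first observation contributes $\log(1+2e^{-2})$ and the second, whose scores are all equal, contributes $\log 3$, so `demoLoss` is the average of those two values.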

Define a function to display training progress information that comes from the workers. The `DataQueue` in this example calls this function every time a worker sends data.

    function displayTrainingProgress(data,line)
        addpoints(line,double(data(3)),double(data(2)))
        D = duration(0,0,data(4),'Format','hh:mm:ss');
        title("Epoch: " + data(1) + ", Elapsed: " + string(D))
        drawnow
    end

Define a function that aggregates the gradients on all workers by adding them together. `gplus` adds together and replicates all the gradients on the workers. Before adding them together, normalize them by multiplying them by a factor that represents the proportion of the overall mini-batch that the worker is working on. To retrieve the contents of a `dlarray`, use `extractdata`.

    function gradients = aggregateGradients(dlgradients,factor)
        gradients = extractdata(dlgradients);
        gradients = gplus(factor*gradients);
    end

Define a function that aggregates the network state on all workers. The network state contains the trained batch normalization statistics of the data set. Because each worker sees only a portion of the mini-batch, aggregate the network state so that the statistics represent all the data. For each mini-batch, the combined mean is the average of the per-worker means weighted by each worker's share of the mini-batch, $\bar{x}_c = \frac{1}{M}\sum_{j=1}^{N} m_j \bar{x}_j$. The combined variance is calculated according to the following formula:

$$s_c^2 = \frac{1}{M}\sum_{j=1}^{N} m_j\left[s_j^2 + \left(\bar{x}_j - \bar{x}_c\right)^2\right]$$

where $N$ is the total number of workers, $M$ is the total number of observations in a mini-batch, $m_j$ is the number of observations processed on the $j$th worker, $\bar{x}_j$ and $s_j^2$ are the mean and variance statistics calculated on that worker, and $\bar{x}_c$ is the combined mean across all workers.

    function state = aggregateState(state,factor)
        numrows = size(state,1);
        % Each check reads rows j and j+1, so stop at the second-to-last row.
        for j = 1:numrows-1
            isBatchNormalizationState = state.Parameter(j) == "TrainedMean" ...
                && state.Parameter(j+1) == "TrainedVariance" ...
                && state.Layer(j) == state.Layer(j+1);
            if isBatchNormalizationState
                meanVal = state.Value{j};
                varVal = state.Value{j+1};

                % Calculate combined mean.
                combinedMean = gplus(factor*meanVal);

                % Calculate combined variance terms to sum.
                combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);

                % Update state.
                state.Value(j) = {combinedMean};
                state.Value(j+1) = {gplus(combinedVarTerm)};
            end
        end
    end
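The following sketch checks the combined mean and variance formulas numerically on one hypothetical feature split across three workers with $m_j = 5, 7, 4$. It uses population variances, `var(x,1)`, which normalize by $m_j$; with that convention the formula reproduces the variance of the whole mini-batch exactly, up to rounding. The weights `demoShares ./ demoM` correspond to the `factor` argument of `aggregateState`, a plain `sum` stands in for `gplus`, and the `demo*` and `worker*` names are illustrative.

```matlab
rng(1)                                   % reproducible illustration
demoShares = [5 7 4];                    % hypothetical m_j for three workers
demoX = randn(1,sum(demoShares));        % one feature of the whole mini-batch
demoM = numel(demoX);                    % M, the mini-batch size

% Per-worker statistics, each computed on that worker's slice of the data.
lastIdx = cumsum(demoShares);
firstIdx = lastIdx - demoShares + 1;
workerMean = zeros(size(demoShares));
workerVar = zeros(size(demoShares));
for j = 1:numel(demoShares)
    part = demoX(firstIdx(j):lastIdx(j));
    workerMean(j) = mean(part);
    workerVar(j) = var(part,1);          % population variance, normalized by m_j
end

% Combine with weights m_j/M, as aggregateState does with its factor input.
weights = demoShares ./ demoM;
combinedMean = sum(weights .* workerMean);
combinedVar = sum(weights .* (workerVar + (workerMean - combinedMean).^2));

% Compare against statistics computed directly on the whole mini-batch.
fprintf('mean: %.6f vs %.6f\n',combinedMean,mean(demoX))
fprintf('variance: %.6f vs %.6f\n',combinedVar,var(demoX,1))
```

The $(\bar{x}_j - \bar{x}_c)^2$ term accounts for the spread between the worker means; averaging only the per-worker variances would underestimate the variance of the whole mini-batch whenever the worker means differ.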

See Also: `crossentropy` | `dlarray` | `dlfeval` | `dlgradient` | `dlnetwork` | `dlupdate` | `forward` | `predict` | `sgdmupdate` | `softmax`