Define Custom Training Loops, Loss Functions, and Networks
For most deep learning tasks, you can use a pretrained neural network and adapt it to your own data. For an example showing how to use transfer learning to retrain a convolutional neural network to classify a new set of images, see Train Deep Learning Network to Classify New Images. Alternatively, you can create and train neural networks from scratch using layerGraph objects with the trainNetwork and trainingOptions functions.
If the trainingOptions function does not provide the training options that you need for your task, then you can create a custom training loop using automatic differentiation. To learn more, see Define Deep Learning Network for Custom Training Loops.
If Deep Learning Toolbox™ does not provide the layers you need for your task (including output layers that specify loss functions), then you can create a custom layer. To learn more, see Define Custom Deep Learning Layers. For loss functions that cannot be specified using an output layer, you can specify the loss in a custom training loop. To learn more, see Specify Loss Functions. For networks that cannot be created using layer graphs, you can define custom networks as a function. To learn more, see Define Network as Model Function.
For more information about which training method to use for which task, see Train Deep Learning Model in MATLAB.
Define Deep Learning Network for Custom Training Loops
Define Network as dlnetwork Object
For most tasks, you can control the training algorithm details using the trainingOptions and trainNetwork functions. If the trainingOptions function does not provide the options you need for your task (for example, a custom learning rate schedule), then you can define your own custom training loop using a dlnetwork object. A dlnetwork object allows you to train a network specified as a layer graph using automatic differentiation.
For networks specified as a layer graph, you can create a dlnetwork object from the layer graph by using the dlnetwork function directly.
net = dlnetwork(lgraph);
For a list of layers supported by dlnetwork objects, see the Supported Layers section of the dlnetwork page. For an example showing how to train a network with a custom learning rate schedule, see Train Network Using Custom Training Loop.
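For example, the following minimal sketch builds a small classification layer graph and converts it to a dlnetwork object. The 28-by-28 grayscale input size and the layer choices are illustrative assumptions only; note that the layer graph must not contain output layers.

layers = [
    imageInputLayer([28 28 1],'Normalization','none') % illustrative input size
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];

lgraph = layerGraph(layers);

% Convert the layer graph to a dlnetwork object for use in a custom training loop.
net = dlnetwork(lgraph);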
Define Network as Model Function
For architectures that cannot be created using layer graphs (for example, a Siamese network that requires shared weights), you can define the model as a function of the form [Y1,...,YM] = model(parameters,X1,...,XN), where parameters contains the network parameters, X1,...,XN corresponds to the input data for the N model inputs, and Y1,...,YM corresponds to the M model outputs. To train a deep learning model defined as a function, use a custom training loop. For an example, see Train Network Using Model Function.
When you define a deep learning model as a function, you must manually initialize the layer weights. For more information, see Initialize Learnable Parameters for Model Function.
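For example, a model function for a simple multilayer perceptron might be sketched as follows. The parameters field names (fc1, fc2) and the two-layer structure are illustrative assumptions; organize your own parameters structure however you initialize it.

function Y = model(parameters,X)
    % First fully connected operation followed by a ReLU activation.
    Y = fullyconnect(X,parameters.fc1.Weights,parameters.fc1.Bias);
    Y = relu(Y);

    % Second fully connected operation followed by a softmax activation.
    Y = fullyconnect(Y,parameters.fc2.Weights,parameters.fc2.Bias);
    Y = softmax(Y);
end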
If you define a custom network as a function, then the model function must support automatic differentiation. You can use the following deep learning operations. The functions listed here are only a subset. For a complete list of functions that support dlarray input, see List of Functions with dlarray Support.
Function | Description |
---|---|
attention | The attention operation focuses on parts of the input using weighted multiplication operations. |
avgpool | The average pooling operation performs downsampling by dividing the input into pooling regions and computing the average value of each region. |
batchnorm | The batch normalization operation normalizes the input data across all observations for each channel independently. To speed up training of the convolutional neural network and reduce the sensitivity to network initialization, use batch normalization between convolution and nonlinear operations such as relu. |
crossentropy | The cross-entropy operation computes the cross-entropy loss between network predictions and target values for single-label and multi-label classification tasks. |
crosschannelnorm | The cross-channel normalization operation uses local responses in different channels to normalize each activation. Cross-channel normalization typically follows a relu operation. Cross-channel normalization is also known as local response normalization. |
ctc | The CTC operation computes the connectionist temporal classification (CTC) loss between unaligned sequences. |
dlconv | The convolution operation applies sliding filters to the input data. Use the dlconv function for deep learning convolution, grouped convolution, and channel-wise separable convolution. |
dlode45 | The neural ordinary differential equation (ODE) operation returns the solution of a specified ODE. |
dltranspconv | The transposed convolution operation upsamples feature maps. |
embed | The embed operation converts numeric indices to numeric vectors, where the indices correspond to discrete data. Use embeddings to map discrete data such as categorical values or words to numeric vectors. |
fullyconnect | The fully connect operation multiplies the input by a weight matrix and then adds a bias vector. |
gelu | The Gaussian error linear unit (GELU) activation operation weights the input by its probability under a Gaussian distribution. |
groupnorm | The group normalization operation normalizes the input data across grouped subsets of channels for each observation independently. To speed up training of the convolutional neural network and reduce the sensitivity to network initialization, use group normalization between convolution and nonlinear operations such as relu. |
gru | The gated recurrent unit (GRU) operation allows a network to learn dependencies between time steps in time series and sequence data. |
huber | The Huber operation computes the Huber loss between network predictions and target values for regression tasks. When the 'TransitionPoint' option is 1, this is also known as smooth L1 loss. |
instancenorm | The instance normalization operation normalizes the input data across each channel for each observation independently. To improve the convergence of training the convolutional neural network and reduce the sensitivity to network hyperparameters, use instance normalization between convolution and nonlinear operations such as relu. |
l1loss | The L1 loss operation computes the L1 loss given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean absolute error (MAE). |
l2loss | The L2 loss operation computes the L2 loss (based on the squared L2 norm) given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean squared error (MSE). |
layernorm | The layer normalization operation normalizes the input data across all channels for each observation independently. To speed up training of recurrent and multilayer perceptron neural networks and reduce the sensitivity to network initialization, use layer normalization after the learnable operations, such as LSTM and fully connect operations. |
leakyrelu | The leaky rectified linear unit (ReLU) activation operation performs a nonlinear threshold operation, where any input value less than zero is multiplied by a fixed scale factor. |
lstm | The long short-term memory (LSTM) operation allows a network to learn long-term dependencies between time steps in time series and sequence data. |
maxpool | The maximum pooling operation performs downsampling by dividing the input into pooling regions and computing the maximum value of each region. |
maxunpool | The maximum unpooling operation unpools the output of a maximum pooling operation by upsampling and padding with zeros. |
mse | The half mean squared error operation computes the half mean squared error loss between network predictions and target values for regression tasks. |
onehotdecode | The one-hot decode operation decodes probability vectors, such as the output of a classification network, into classification labels. |
relu | The rectified linear unit (ReLU) activation operation performs a nonlinear threshold operation, where any input value less than zero is set to zero. |
sigmoid | The sigmoid activation operation applies the sigmoid function to the input data. |
softmax | The softmax activation operation applies the softmax function to the channel dimension of the input data. |
Specify Loss Functions
When you use a custom training loop, you must calculate the loss in the model loss function. Use the loss value when computing gradients for updating the network weights. To compute the loss, you can use the following functions.
Function | Description |
---|---|
softmax | The softmax activation operation applies the softmax function to the channel dimension of the input data. |
sigmoid | The sigmoid activation operation applies the sigmoid function to the input data. |
crossentropy | The cross-entropy operation computes the cross-entropy loss between network predictions and target values for single-label and multi-label classification tasks. |
l1loss | The L1 loss operation computes the L1 loss given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean absolute error (MAE). |
l2loss | The L2 loss operation computes the L2 loss (based on the squared L2 norm) given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean squared error (MSE). |
huber | The Huber operation computes the Huber loss between network predictions and target values for regression tasks. When the 'TransitionPoint' option is 1, this is also known as smooth L1 loss. |
mse | The half mean squared error operation computes the half mean squared error loss between network predictions and target values for regression tasks. |
ctc | The CTC operation computes the connectionist temporal classification (CTC) loss between unaligned sequences. |
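For example, for a single-label classification task, a common pattern is to convert the network output to probabilities with softmax and then compute the cross-entropy loss against one-hot encoded targets. This is a sketch with assumed variable names Y (network output) and T (targets):

Y = softmax(Y);              % convert scores to probabilities
loss = crossentropy(Y,T);    % cross-entropy against one-hot encoded targets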
Alternatively, you can use a custom loss function by creating a function of the form loss = myLoss(Y,T), where Y and T correspond to the network predictions and targets, respectively, and loss is the returned loss.
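For example, a custom loss function that computes the mean absolute error over all elements might be sketched as follows, assuming Y and T are dlarray objects of the same size:

function loss = myLoss(Y,T)
    % Mean absolute error over all elements (illustrative custom loss).
    loss = mean(abs(Y - T),'all');
end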
For an example showing how to train a generative adversarial network (GAN) that generates images using a custom loss function, see Train Generative Adversarial Network (GAN).
Update Learnable Parameters Using Automatic Differentiation
When you train a deep learning model with a custom training loop, the software minimizes the loss with respect to the learnable parameters. To minimize the loss, the software uses the gradients of the loss with respect to the learnable parameters. To calculate these gradients using automatic differentiation, you must define a model loss function.
Define Model Loss Function
For a model specified as a dlnetwork object, create a function of the form [loss,gradients] = modelLoss(net,X,T), where net is the network, X is the network input, T contains the targets, and loss and gradients are the returned loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated network state).
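For example, a model loss function for a dlnetwork-based classifier might be sketched as follows. The use of cross-entropy, and the assumptions that the network ends in a softmax operation and that T contains one-hot encoded targets, are illustrative.

function [loss,gradients] = modelLoss(net,X,T)
    % Forward pass through the network (assumes the final operation is a softmax).
    Y = forward(net,X);

    % Cross-entropy loss against one-hot encoded targets.
    loss = crossentropy(Y,T);

    % Gradients of the loss with respect to the learnable parameters.
    gradients = dlgradient(loss,net.Learnables);
end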
For a model specified as a function, create a function of the form [loss,gradients] = modelLoss(parameters,X,T), where parameters contains the learnable parameters, X is the model input, T contains the targets, and loss and gradients are the returned loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated model state).
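For a model defined as a function, the pattern is the same except that the model function is called directly and the gradients are computed with respect to the parameters structure, as in this sketch (which assumes the illustrative model function shown earlier in this topic):

function [loss,gradients] = modelLoss(parameters,X,T)
    % Evaluate the model function.
    Y = model(parameters,X);

    % Cross-entropy loss against one-hot encoded targets (illustrative choice).
    loss = crossentropy(Y,T);

    % Gradients of the loss with respect to the learnable parameters.
    gradients = dlgradient(loss,parameters);
end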
To learn more about defining model loss functions for custom training loops, see Define Model Loss Function for Custom Training Loop.
Update Learnable Parameters
To evaluate the model loss function using automatic differentiation, use the dlfeval function, which evaluates a function with automatic differentiation enabled. For the first input of dlfeval, pass the model loss function specified as a function handle. For the following inputs, pass the required variables for the model loss function. For the outputs of the dlfeval function, specify the same outputs as the model loss function.
To update the learnable parameters using the gradients, you can use the following functions.
Function | Description |
---|---|
adamupdate | Update parameters using adaptive moment estimation (Adam) |
rmspropupdate | Update parameters using root mean squared propagation (RMSProp) |
sgdmupdate | Update parameters using stochastic gradient descent with momentum (SGDM) |
lbfgsupdate | Update parameters using limited-memory BFGS (L-BFGS) |
dlupdate | Update parameters using custom function |
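For example, one iteration of a custom training loop that uses the Adam solver might be sketched as follows. The mini-batch variables X and T, the iteration counter, and the averageGrad and averageSqGrad states (initialized to empty arrays before the loop) are assumptions for illustration.

% Evaluate the model loss and gradients with automatic differentiation enabled.
[loss,gradients] = dlfeval(@modelLoss,net,X,T);

% Update the network parameters using the Adam solver.
[net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
    averageGrad,averageSqGrad,iteration);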
See Also
dlarray | dlgradient | dlfeval | dlnetwork
Related Topics
- Train Generative Adversarial Network (GAN)
- Train Network Using Custom Training Loop
- Specify Training Options in Custom Training Loop
- Define Model Loss Function for Custom Training Loop
- Update Batch Normalization Statistics in Custom Training Loop
- Update Batch Normalization Statistics Using Model Function
- Make Predictions Using dlnetwork Object
- Make Predictions Using Model Function
- Train Network Using Model Function
- Initialize Learnable Parameters for Model Function
- Train Deep Learning Model in MATLAB
- Define Custom Deep Learning Layers
- List of Functions with dlarray Support
- Automatic Differentiation Background
- Use Automatic Differentiation In Deep Learning Toolbox