
Custom Training Loops

Customize deep learning training loops and loss functions

If the trainingOptions function does not provide the training options that you need for your task, or custom output layers do not support the loss functions that you need, then you can define a custom training loop. For models that cannot be specified as networks of layers, you can define the model as a function. To learn more, see Define Custom Training Loops, Loss Functions, and Networks.
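The sketch below shows how several of the functions listed on this page can work together in a custom training loop. `dlnetwork` defines the model, `dlfeval` and `dlgradient` compute the loss gradients, `mse` supplies the loss, and `adamupdate` updates the learnable parameters. It assumes Deep Learning Toolbox in a recent release (R2021a or later). The synthetic regression data, layer sizes, the `modelLoss` helper name, and the hyperparameter values are illustrative choices only, not recommendations. For brevity, the whole data set is used as one batch. A real loop would typically draw mini-batches from a `minibatchqueue`.

```matlab
% Synthetic regression data (illustrative assumption): y = 3x + 2 plus noise.
rng(0);
numObs = 256;
X = rand(1,numObs,"single");
T = 3*X + 2 + 0.05*randn(1,numObs,"single");

% Small network; dlnetwork initializes the learnables from the input layer.
layers = [
    featureInputLayer(1)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(1)];
net = dlnetwork(layers);

% Formatted dlarray objects: "C" = channel (feature), "B" = batch.
Xdl = dlarray(X,"CB");
Tdl = dlarray(T,"CB");

numIterations = 500;
learnRate = 0.01;
averageGrad = [];      % Adam state, empty before the first update
averageSqGrad = [];

for iteration = 1:numIterations
    % dlfeval traces modelLoss so that dlgradient can differentiate it.
    [loss,gradients] = dlfeval(@modelLoss,net,Xdl,Tdl);
    if iteration == 1
        initialLoss = extractdata(loss);
    end

    % Adam update of all learnable parameters in net.
    [net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
        averageGrad,averageSqGrad,iteration,learnRate);
end

finalLoss = extractdata(loss);
fprintf("Loss: %.4f (first iteration) -> %.4f (last iteration)\n",initialLoss,finalLoss);

% Hypothetical helper: loss and gradients for one batch.
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);                          % training-mode forward pass
    loss = mse(Y,T);                             % half mean squared error
    gradients = dlgradient(loss,net.Learnables); % table matching net.Learnables
end
```

Swapping `adamupdate` for `sgdmupdate` or `rmspropupdate` changes only the update call and the solver state variables it carries between iterations.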

Functions


dlnetwork: Deep learning neural network (Since R2019b)
imagePretrainedNetwork: Pretrained neural network for images (Since R2024a)
resnetNetwork: 2-D residual neural network (Since R2024a)
resnet3dNetwork: 3-D residual neural network (Since R2024a)
addLayers: Add layers to neural network
removeLayers: Remove layers from neural network
replaceLayer: Replace layer in neural network
connectLayers: Connect layers in neural network
disconnectLayers: Disconnect layers in neural network
addInputLayer: Add input layer to network (Since R2022b)
initialize: Initialize learnable and state parameters of a dlnetwork (Since R2021a)
networkDataLayout: Deep learning network data layout for learnable parameter initialization (Since R2022b)
setL2Factor: Set L2 regularization factor of layer learnable parameter
getL2Factor: Get L2 regularization factor of layer learnable parameter
setLearnRateFactor: Set learn rate factor of layer learnable parameter
getLearnRateFactor: Get learn rate factor of layer learnable parameter
plot: Plot neural network architecture
summary: Print network summary (Since R2022b)
analyzeNetwork: Analyze deep learning network architecture
checkLayer: Check validity of custom or function layer
isequal: Check equality of neural networks (Since R2021a)
isequaln: Check equality of neural networks ignoring NaN values (Since R2021a)
forward: Compute deep learning network output for training (Since R2019b)
predict: Compute deep learning network output for inference (Since R2019b)
adamupdate: Update parameters using adaptive moment estimation (Adam) (Since R2019b)
rmspropupdate: Update parameters using root mean squared propagation (RMSProp) (Since R2019b)
sgdmupdate: Update parameters using stochastic gradient descent with momentum (SGDM) (Since R2019b)
lbfgsupdate: Update parameters using limited-memory BFGS (L-BFGS) (Since R2023a)
lbfgsState: State of limited-memory BFGS (L-BFGS) solver (Since R2023a)
dlupdate: Update parameters using custom function (Since R2019b)
trainingProgressMonitor: Monitor and plot training progress for deep learning custom training loops (Since R2022b)
updateInfo: Update information values for custom training loops (Since R2022b)
recordMetrics: Record metric values for custom training loops (Since R2022b)
groupSubPlot: Group metrics in training plot (Since R2022b)
padsequences: Pad or truncate sequence data to same length (Since R2021a)
minibatchqueue: Create mini-batches for deep learning (Since R2020b)
onehotencode: Encode data labels into one-hot vectors (Since R2020b)
onehotdecode: Decode probability vectors into class labels (Since R2020b)
next: Obtain next mini-batch of data from minibatchqueue (Since R2020b)
reset: Reset minibatchqueue to start of data (Since R2020b)
shuffle: Shuffle data in minibatchqueue (Since R2020b)
hasdata: Determine if minibatchqueue can return mini-batch (Since R2020b)
partition: Partition minibatchqueue (Since R2020b)
dlarray: Deep learning array for customization (Since R2019b)
dlgradient: Compute gradients for custom training loops using automatic differentiation (Since R2019b)
dlfeval: Evaluate deep learning model for custom training loops (Since R2019b)
dims: Dimension labels of dlarray (Since R2019b)
finddim: Find dimensions with specified label (Since R2019b)
stripdims: Remove dlarray data format (Since R2019b)
extractdata: Extract data from dlarray (Since R2019b)
isdlarray: Check if object is dlarray (Since R2020b)
crossentropy: Cross-entropy loss for classification tasks (Since R2019b)
l1loss: L1 loss for regression tasks (Since R2021b)
l2loss: L2 loss for regression tasks (Since R2021b)
huber: Huber loss for regression tasks (Since R2021a)
mse: Half mean squared error (Since R2019b)
ctc: Connectionist temporal classification (CTC) loss for unaligned sequence classification (Since R2021a)
dlaccelerate: Accelerate deep learning function for custom training loops (Since R2021a)
AcceleratedFunction: Accelerated deep learning function (Since R2021a)
clearCache: Clear accelerated deep learning function trace cache (Since R2021a)
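As an illustration of the mini-batch functions in this list, the sketch below builds a `minibatchqueue` from combined `arrayDatastore` objects. It loops over epochs with `shuffle`, `hasdata`, and `next`, converts labels with `onehotencode` and `onehotdecode`, and computes a `crossentropy` loss. It assumes R2021a or later (for the name=value argument syntax). The random data, the class names, the mini-batch size, and the `preprocessMiniBatch` helper are hypothetical. The loop body where a forward pass and parameter update would normally go is left as a comment.

```matlab
% Hypothetical toy data: 100 observations, 4 features, 3 classes.
rng(0);
numObs = 100;
XTrain = rand(numObs,4,"single");                        % N-by-4 predictors
classNames = ["a" "b" "c"];
TTrain = categorical(randi(3,numObs,1),1:3,classNames);  % N-by-1 labels

% Combine predictor and label datastores; each read returns one observation.
dsTrain = combine(arrayDatastore(XTrain),arrayDatastore(TTrain));

mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize=32, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["CB" "CB"], ...
    OutputEnvironment="cpu");

numEpochs = 2;
numBatches = 0;
for epoch = 1:numEpochs
    shuffle(mbq);                 % reshuffle and reset at the start of each epoch
    while hasdata(mbq)
        [X,T] = next(mbq);        % X: 4-by-B, T: 3-by-B formatted dlarray
        numBatches = numBatches + 1;
        % Forward pass, loss, gradients, and parameter update would go here.
    end
end

% Illustrative loss: uniform class scores against the last mini-batch targets.
Y = dlarray(ones(size(T),"single")/3,"CB");
loss = crossentropy(Y,T);

% Convert the one-hot targets back to class labels (1-by-B categorical).
decodedLabels = onehotdecode(extractdata(T),classNames,1);

% Hypothetical helper: concatenate observations and one-hot encode labels.
function [X,T] = preprocessMiniBatch(dataX,dataT)
    X = cat(1,dataX{:})';         % numFeatures-by-miniBatchSize
    T = cat(2,dataT{:});          % 1-by-miniBatchSize categorical
    T = onehotencode(T,1);        % numClasses-by-miniBatchSize
end
```

With 100 observations and a mini-batch size of 32, each epoch yields three full mini-batches and one partial mini-batch of 4 observations. This is because `minibatchqueue` returns partial mini-batches by default.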

Topics

Custom Training Loops

Automatic Differentiation

Deep Learning Function Acceleration
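For the automatic differentiation and acceleration topics, here is a small sketch. It evaluates the Rosenbrock function f(x) = 100(x2 - x1^2)^2 + (1 - x1)^2 and its gradient with `dlfeval` and `dlgradient`. It then repeats the evaluation through a `dlaccelerate` wrapper and clears the trace cache with `clearCache`. The Rosenbrock function and the evaluation point are illustrative choices. Analytically, at (x1, x2) = (-1, 2), f = 100 + 4 = 104, df/dx1 = -400 x1 (x2 - x1^2) - 2(1 - x1) = 396, and df/dx2 = 200 (x2 - x1^2) = 200.

```matlab
% Evaluation point as an unformatted dlarray.
x0 = dlarray([-1 2]);

% dlfeval traces rosenbrock so that dlgradient can differentiate it.
[fval,gradval] = dlfeval(@rosenbrock,x0);

% Same computation through an accelerated (traced and cached) function.
accfun = dlaccelerate(@rosenbrock);
[fvalAcc,gradvalAcc] = dlfeval(accfun,x0);
clearCache(accfun);               % discard the cached traces

fprintf("f = %g, grad = [%g %g]\n",extractdata(fval),extractdata(gradval));
% prints f = 104, grad = [396 200]

% Illustrative function: value and gradient of the Rosenbrock function.
function [y,dydx] = rosenbrock(x)
    y = 100*(x(2) - x(1).^2).^2 + (1 - x(1)).^2;
    dydx = dlgradient(y,x);
end
```

Acceleration generally pays off only when the same function is called many times with inputs of the same size and type, as in a training loop. For a single call like this one, tracing adds overhead.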

Related Information