Training A Model From Scratch - MATLAB & Simulink

Training a Model from Scratch

In this example, we want to train a convolutional neural network (CNN) to identify handwritten digits. We will use data from the Digits data set, which contains 10,000 images of handwritten numbers 0-9.

Figure: A random sample of 25 handwritten numbers in the Digits data set.

By using a simple data set, we'll be able to cover all the key steps in the deep learning workflow without dealing with challenges such as limited processing power or data sets that are too large to fit into memory. You can apply the workflow described here to more complex deep learning problems and larger data sets.

If you are just getting started with applying deep learning, another advantage to using this data set is that you can train a network on it without investing in an expensive GPU.

For this simple data set, with the right deep learning model and training options, it is possible to achieve almost 100% accuracy. So how do we create a model that will get us to that point?

This will be an iterative process in which we build on previous training results to figure out how to approach the training problem. The steps are as follows:

1. Accessing the Data

We begin by loading the Digits images into MATLAB. Data sets are stored in many different file types. This data is stored as a collection of image files. These lines of code will create a datastore for image data, which helps you manage the image files.


digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true,LabelSource="foldernames");
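
To get a feel for the data before going further, it can help to count the images in each class and look at a few of them. The following is a minimal sketch, not part of the original example: it recreates the datastore so that it runs on its own, and the variable names labelCount, numImages, and idx are illustrative choices.

```matlab
% Recreate the datastore so this sketch runs on its own.
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true,LabelSource="foldernames");

% Count the images in each class (one row per digit).
labelCount = countEachLabel(imds)

% Show 25 randomly chosen images in a 5-by-5 grid.
numImages = numel(imds.Files);
idx = randperm(numImages,25);
figure
for i = 1:25
    subplot(5,5,i)
    imshow(imds.Files{idx(i)})   % imshow accepts an image file name
end
```

Because the indices are drawn with randperm, a different set of digits appears each time the sketch runs.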

Check the size of the first image. These images are quite small – only 28 x 28 pixels.

img = readimage(imds,1);
size(img)

ans =

    28    28

Divide the data into training and validation data sets, so that the training set contains 750 images of each digit (7,500 images in total) and the validation set contains the remaining images.


numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,"randomize");
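
A quick check of the split can confirm that splitEachLabel takes numTrainFiles images from every label rather than from the data set as a whole. The sketch below is an illustration under that reading; it repeats the earlier steps so that it runs on its own, and numTrain and numValidation are names introduced here.

```matlab
% Recreate the datastore and the split so this sketch runs on its own.
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true,LabelSource="foldernames");
numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,"randomize");

% 750 images from each of the 10 labels go to training.
numTrain = numel(imdsTrain.Files)              % 7500
numValidation = numel(imdsValidation.Files)    % the remaining 2500

% Per-class counts for the training set.
countEachLabel(imdsTrain)
```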

The next task would be image labeling, but since the Digits images come with labels, we can skip that tedious step and quickly move on to building our neural network.

2. Creating and Configuring Network Layers

We'll start by building a CNN, a common kind of deep learning network for classifying images.

About CNNs

A CNN takes an image, passes it through the network layers, and outputs a final class. The network can have tens or hundreds of layers, with each layer learning to detect different features of an image. Filters are applied to each training image at different resolutions, and the output of each convolved image is used as the input to the next layer. The filters can start as very simple features, such as brightness and edges, and increase in complexity to features that uniquely define the object as the layers progress.
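
To make the idea of a filter detecting a feature concrete, here is a small sketch in base MATLAB. It is an illustration only: the 6-by-6 image and the Sobel-like edge filter are made up for this example, and in a real CNN the filter weights are learned during training rather than chosen by hand.

```matlab
% A tiny 6-by-6 "image": dark left half (0), bright right half (1).
img = [zeros(6,3) ones(6,3)];

% A hand-picked 3-by-3 vertical-edge filter (an assumption for illustration).
k = [-1 0 1; -2 0 2; -1 0 1];

% conv2 flips the kernel, so rotate it by 180 degrees to get the
% sliding-window product that a convolution layer computes.
response = conv2(img,rot90(k,2),'valid')
% Each row of response is [0 4 4 0]: the filter responds only
% where the window straddles the dark-to-bright edge.
```

The output of a filter like this is what the next layer receives as its input, which is how later layers can combine simple features into more complex ones.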

Since we're training the CNN from scratch, we must first specify which layers it will contain and in what order.

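layers = [
    % Input: 28-by-28 grayscale images (one channel)
    imageInputLayer([28 28 1])

    % Block 1: eight 3-by-3 filters; "same" padding keeps the spatial size
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer

    % Downsample by a factor of 2
    maxPooling2dLayer(2,Stride=2)

    % Block 2: sixteen 3-by-3 filters
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer

    % Downsample by a factor of 2
    maxPooling2dLayer(2,Stride=2)

    % Block 3: thirty-two 3-by-3 filters
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    % One output per digit class, converted to class probabilities
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];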
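
To see how the image size and the number of learnable parameters change through this network, the arithmetic can be worked out by hand. The sketch below is a back-of-the-envelope calculation in base MATLAB, not output from the toolbox; the variable names are introduced here, and the batch normalization parameters are left out of the total.

```matlab
% Spatial size after each stage. "same" padding keeps the size;
% each 2-by-2 pooling layer with stride 2 halves it.
inputSize  = 28;
afterPool1 = inputSize/2;       % 14
afterPool2 = afterPool1/2;      % 7

% Weights plus biases of the three 3-by-3 convolution layers:
% filterHeight*filterWidth*inputChannels*numFilters + numFilters
conv1Params = 3*3*1*8   + 8;    % 80
conv2Params = 3*3*8*16  + 16;   % 1168
conv3Params = 3*3*16*32 + 32;   % 4640

% The fully connected layer sees a 7-by-7-by-32 feature map.
fcParams = afterPool2^2*32*10 + 10;   % 15690

totalParams = conv1Params + conv2Params + conv3Params + fcParams   % 21578
```

Most of the parameters sit in the fully connected layer. If Deep Learning Toolbox is available, analyzeNetwork(layers) is one way to compare these hand-computed sizes against what the toolbox reports.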

You can learn more about all of these layers in the documentation.

3. Training the Network

First, we select training options. There are many options available. The table shows the most commonly used options.

Commonly Used Training Options

Training Option: Plot of training progress
Definition: The plot shows the mini-batch loss and accuracy. It includes a stop button that lets you halt network training at any point.
Hint: Use Plots="training-progress" to plot the progress of the network as it trains.

Training Option: Max epochs
Definition: An epoch is one full pass of the training algorithm over the entire training set.
Hint: MaxEpochs=20. The more epochs specified, the longer the network will train, but the accuracy may improve with each epoch.

Training Option: Mini-batch size
Definition: A mini-batch is a subset of the training data set that is processed at the same time.
Hint: MiniBatchSize=64. The larger the mini-batch, the faster the training, but the maximum size is limited by the GPU memory. If you get a memory error when training, reduce the mini-batch size.

Training Option: Learning rate
Definition: A major parameter that controls the speed of training.
Hint: A lower learning rate can give a more accurate result, but the network may take longer to train.
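
The epoch and mini-batch settings together determine how many training iterations run. The sketch below works this out for the example values in the table, assuming a training set of 7,500 images (750 per digit) and assuming that a final partial mini-batch in each epoch is dropped; the variable names are introduced here for illustration.

```matlab
numTrainImages = 7500;   % assumption: 750 images for each of 10 digits
miniBatchSize  = 64;     % example value from the table
maxEpochs      = 20;     % example value from the table

% One iteration processes one mini-batch.
iterationsPerEpoch = floor(numTrainImages/miniBatchSize)   % 117
totalIterations    = iterationsPerEpoch*maxEpochs          % 2340

% The same values passed to trainingOptions (requires Deep Learning Toolbox).
exampleOptions = trainingOptions("sgdm", ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    Plots="training-progress");
```

Halving the mini-batch size roughly doubles the number of iterations per epoch, which is one reason smaller mini-batches tend to train more slowly.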

We specify the training options, then train the network and monitor the training progress.

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.1, ...
    MaxEpochs=2, ...
    Plots="training-progress");    

net = trainNetwork(imdsTrain,layers,options);

4. Checking Network Accuracy

Our goal is to have the accuracy of the model increase over time. As the network trains, the progress plot appears.

To improve the result, we'll try altering the training options and the network configuration.

Changing Training Options

We reduce the initial learning rate and increase the number of epochs.

InitialLearnRate=0.01
MaxEpochs=4
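
Put together, the revised training call might look like the sketch below. It keeps the "sgdm" solver and the progress plot from the first run and changes only the two values above. The guard around trainNetwork is an addition for this sketch, so that the block runs even when imdsTrain and layers from the earlier steps are not in the workspace.

```matlab
% Revised options: lower learning rate, more epochs.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.01, ...
    MaxEpochs=4, ...
    Plots="training-progress");

% Retrain only if the training data and layers from the earlier
% steps are available in the workspace.
if exist("imdsTrain","var") && exist("layers","var")
    net = trainNetwork(imdsTrain,layers,options);
end
```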

As a result of changing these parameters, we get a much better result: nearly 100% accuracy.

Training progress

Validation accuracy

After the network has trained, we test it on the images of the validation set.


YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)

accuracy =

 0.9968
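
A single accuracy figure can hide a class that the network handles poorly, so it can be worth breaking the result down by digit. The sketch below shows one way to do that in base MATLAB. The short label vectors are made-up stand-ins for YPred and YValidation so that the block runs on its own; they are not results from the trained network.

```matlab
% Hypothetical stand-ins for YValidation and YPred (not real results).
YValidation = categorical([0 0 1 1 2 2 2])';
YPred       = categorical([0 1 1 1 2 2 0])';

% Overall accuracy: fraction of predictions that match the labels.
accuracy = mean(YPred == YValidation)   % 5 of 7 correct

% Accuracy for each class separately.
classes = categories(YValidation);
perClass = zeros(numel(classes),1);
for i = 1:numel(classes)
    isClass = YValidation == classes{i};             % images of this digit
    perClass(i) = mean(YPred(isClass) == classes{i});
    fprintf("Digit %s: %.2f\n",classes{i},perClass(i));
end
```

With the real YPred and YValidation from the step above, the same loop reports the accuracy for each of the ten digits.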

Now that we know how accurate the trained network is, we can use it to identify handwritten digits in online images, or even in a live video stream.

Changing the Network Configuration

Sometimes increasing the accuracy requires a deeper network and many rounds of trial and error. We can add more layers, including batch normalization layers, which help speed up network convergence (the point at which further training no longer meaningfully improves the loss). This creates a “deeper” network.
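
As an illustration of what a deeper configuration could look like, the sketch below extends the network from step 2 with one more pooling layer and a fourth convolution block. This variant is hypothetical: the choice of 64 filters and the name deeperLayers are assumptions made here, and nothing in this example shows that it improves accuracy on the Digits data.

```matlab
deeperLayers = [
    imageInputLayer([28 28 1])

    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,Stride=2)

    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,Stride=2)

    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    % Added for this sketch: one more pooling layer and a fourth block.
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(3,64,Padding="same")
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
```

The deeper array can be passed to trainNetwork in place of layers, using the same training options, to compare the two configurations.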

When creating a network from scratch, you are responsible for determining the network configuration. This approach gives you the most control over the network, and can produce impressive results, but it requires an understanding of the structure of a neural network and the many options for layer types and configuration.

