Quantize Multiple-Input Network Using Image and Feature Data

This example shows how to quantize a network with multiple inputs. The network classifies handwritten digits using both image and feature input data. To learn more about multi-input networks, see Multiple-Input and Multiple-Output Networks.

Load Training Data

Load the training data. The digitTrain4DArrayData function loads the images, labels, and clockwise rotation angles of the digits data set as numeric arrays. To learn more about the digits data set used in this example, see Data Sets for Deep Learning.

[X1Train,TTrain,X2Train] = digitTrain4DArrayData;

To train the network using both the image and feature data, create a single datastore that contains the training predictors and responses. Convert the numeric arrays to datastores using arrayDatastore. Use the combine function to combine the datastores into a single datastore.

dsX1Train = arrayDatastore(X1Train,IterationDimension=4);
dsX2Train = arrayDatastore(X2Train);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);
classes = categories(TTrain);
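
As a rough illustration of what the combined datastore returns, the following sketch builds a small datastore from synthetic arrays and reads one observation. The array sizes and variable names are placeholders chosen for illustration, not the digits data.

```matlab
% Synthetic stand-ins for the image, angle, and label arrays (assumed sizes).
X1 = rand(28,28,1,5);
X2 = rand(5,1);
T = categorical((0:4)');

ds = combine( ...
    arrayDatastore(X1,IterationDimension=4), ...
    arrayDatastore(X2), ...
    arrayDatastore(T));

% Each read returns one observation as a 1-by-3 cell array:
% {image, feature, label}.
sample = read(ds);
disp(size(sample))
```

The column order follows the order of the arguments to combine, which is why the responses end up in the last column.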

Specify Training Options

Specify the training options.

  • Train using the SGDM optimizer.

  • Train for 15 epochs.

  • Train with a learning rate of 0.01.

  • Display the training progress in a plot.

  • Suppress the verbose output.

options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Verbose=false);

Train Network

Train the network using the trainDigitsNetwork function. To learn more about how to define the network architecture, see Train Network on Image and Feature Data.

net = trainDigitsNetwork(dsTrain,classes,options)

net = 
  dlnetwork with properties:

         Layers: [10×1 nnet.cnn.layer.Layer]
    Connections: [9×2 table]
     Learnables: [8×3 table]
          State: [2×3 table]
     InputNames: {'imageinput'  'features'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

Test Network

Test the classification accuracy of the network by comparing the predictions on a test set of data with the true labels. Load the test data and create a combined datastore containing the images and features.

[X1Test,TTest,X2Test] = digitTest4DArrayData;
dsX1Test = arrayDatastore(X1Test,IterationDimension=4);
dsX2Test = arrayDatastore(X2Test);
dsTTest = arrayDatastore(TTest);
dsTest = combine(dsX1Test,dsX2Test,dsTTest);

Create a minibatchqueue object that preprocesses the test data and returns mini-batches for dlnetwork prediction.

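% The mini-batch formats are inferred from the preprocessing function: image
% data is "SSCB" (spatial, spatial, channel, batch), feature data is "BC"
% (batch, channel), and the one-hot labels are left unformatted.
mbqTest = minibatchqueue(dsTest,...
    MiniBatchFcn=@preprocessMiniBatchTraining, ...
    OutputAsDlarray=[1 1 1], ...
    OutputEnvironment=["auto","auto","auto"], ...
    PartialMiniBatch="return", ...
    MiniBatchFormat=["SSCB","BC",""]);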

Use the modelAccuracy function to evaluate the accuracy of the network on the test data set.

accuracyOriginal = modelAccuracy(net,mbqTest,classes,dsTest.numpartitions)
accuracyOriginal = 98.4600

Use the modelPredictions function to compute the predicted classes. Visualize the predictions using a confusionchart.

YTest = modelPredictions(net,mbqTest,classes);
confusionchart(TTest,YTest)

Evaluate the classification accuracy based on the model predictions.

accuracy = mean(YTest == TTest)
accuracy = 0.9846

To observe the classification results, view some of the images with their prediction labels.

idx = randperm(size(X1Test,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = X1Test(:,:,:,idx(i));
    imshow(I)

    label = string(YTest(idx(i)));
    title("Predicted Label: " + label)
end

Quantize Network

To quantize a network with multiple inputs, the input data for the calibrate and validate functions must be a combinedDatastore or a transformedDatastore.

For validation, the datastore must output a cell array with (numInputs+1) columns, where numInputs is the number of inputs to the network. In this case, the first numInputs columns specify the predictors for each input and the last column specifies the responses.

Create calibration and validation datastores using random subsets of the test data set.

randomImagesCalibration = randperm(4999);
calibrationDataStore = dsTest.subset(randomImagesCalibration(1:200));
randomImagesValidation = randperm(4999);
validationDataStore = dsTest.subset(randomImagesValidation(1:100));
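
The two randperm calls above are independent, so the calibration and validation subsets can share observations. If disjoint subsets are preferred, one option is to draw both index sets from a single permutation. This is a sketch only: numObservations is a hypothetical value and the index variable names are placeholders.

```matlab
numObservations = 5000;              % hypothetical number of test observations
perm = randperm(numObservations);

% Take non-overlapping slices of the same permutation.
calibrationIdx = perm(1:200);
validationIdx = perm(201:300);

isDisjoint = isempty(intersect(calibrationIdx,validationIdx));
```

The resulting index vectors could then be passed to subset in place of the index vectors used above.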

Create a dlquantizer object and specify the network to quantize.

quantObj = dlquantizer(net,ExecutionEnvironment="MATLAB"); 

Use the calibrate function to exercise the network with the calibration data and collect range statistics for the weights, biases, and activations at each layer.

calResults = calibrate(quantObj,calibrationDataStore)
calResults=16×5 table
    Optimized Layer Name    Network Layer Name    Learnables / Activations     MinValue       MaxValue 
    ____________________    __________________    ________________________    ___________    __________

      {'conv_Weights'}        {'conv'      }           "Weights"                 -0.28447       0.36445
      {'conv_Bias'   }        {'conv'      }           "Bias"                 -8.5358e-07    1.2699e-06
      {'fc_1_Weights'}        {'fc_1'      }           "Weights"                -0.084955      0.077845
      {'fc_1_Bias'   }        {'fc_1'      }           "Bias"                   -0.014489      0.016811
      {'fc_2_Weights'}        {'fc_2'      }           "Weights"                 -0.45607       0.40908
      {'fc_2_Bias'   }        {'fc_2'      }           "Bias"                   -0.020831      0.020135
      {'imageinput'  }        {'imageinput'}           "Activations"                    0             1
      {'features'    }        {'features'  }           "Activations"                  -45            45
      {'conv'        }        {'conv'      }           "Activations"              -1.8417        1.1134
      {'batchnorm'   }        {'batchnorm' }           "Activations"              -9.5983        10.389
      {'relu'        }        {'relu'      }           "Activations"                    0        10.389
      {'fc_1'        }        {'fc_1'      }           "Activations"              -13.472        14.063
      {'flatten'     }        {'flatten'   }           "Activations"              -13.472        14.063
      {'cat'         }        {'cat'       }           "Activations"                  -45            45
      {'fc_2'        }        {'fc_2'      }           "Activations"                -38.1        36.679
      {'softmax'     }        {'softmax'   }           "Activations"           4.1264e-31             1
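
To give a sense of why the collected ranges matter, the following sketch maps one value to a signed 8-bit integer using a power-of-two scale derived from a minimum and maximum. This is an illustrative scheme under stated assumptions, not a description of what dlquantizer does internally. The range is the conv weights row from the table above, and the sample value w is made up.

```matlab
% Range of the conv weights, copied from the calibration table.
minValue = -0.28447;
maxValue = 0.36445;

% Choose the smallest power-of-two scale for which the largest magnitude
% still fits in the signed 8-bit range [-127, 127].
absMax = max(abs([minValue maxValue]));
exponent = ceil(log2(absMax/127));
scale = 2^exponent;

% Quantize a made-up weight value and map it back to floating point.
w = 0.1234;
q = int8(round(w/scale));
wHat = double(q)*scale;

quantizationError = abs(wHat - w);
```

With this scheme, a wider calibrated range forces a larger scale and therefore a coarser step between representable values.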

Use the validate function to compare the results of the network before and after quantization using the validation data set. To validate the dlnetwork, define a dlquantizationOptions object and specify a custom metric function. The hComputeModelAccuracy metric function uses the classes from the training data to compare the predicted labels to the labels in the validation data.

dlquantOpts = dlquantizationOptions; 
dlquantOpts.MetricFcn = {@(x)hComputeModelAccuracy(x,net,validationDataStore,classes)}
dlquantOpts = 
  dlquantizationOptions with properties:

   Validation Metric Info
    MetricFcn: {@(x)hComputeModelAccuracy(x,net,validationDataStore,classes)}

   Validation Environment Info
       Target: 'host'
    Bitstream: ''

valResults = validate(quantObj,validationDataStore,dlquantOpts);

Examine the MetricResults.Result field of the validation output to view the accuracy of the quantized network and the floating-point network.

valResults.MetricResults.Result
ans=2×2 table
    NetworkImplementation    MetricOutput
    _____________________    ____________

     {'Floating-Point'}          0.99    
     {'Quantized'     }          0.99    

Supporting Functions

Train Network

The trainDigitsNetwork function takes as input a CombinedDatastore, the network classes, and the training options, and trains the network using the trainnet function.

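function net = trainDigitsNetwork(dsTrain,classes,options)

% Define network
imageInputSize = [28 28 1];
filterSize = 5;
numFilters = 16;
numClasses = numel(classes);

% Layer stack inferred from the layer names reported in the network summary
% and calibration table. The output size of the first fully connected layer
% (50) is an assumption.
layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(50)
    flattenLayer
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

lgraph = layerGraph(layers);

% Add the feature input and connect it to the second input of the
% concatenation layer.
featInput = featureInputLayer(1,Name="features");
lgraph = addLayers(lgraph,featInput);
lgraph = connectLayers(lgraph,"features","cat/in2");
dlnet = dlnetwork(lgraph);

net = trainnet(dsTrain,dlnet,"crossentropy",options);
end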

Mini-Batch Preprocessing Function

The preprocessMiniBatchTraining function preprocesses a mini-batch of predictors and labels for loss computation during training.

function [X1, X2, T] = preprocessMiniBatchTraining(X1Cell, X2Cell,TCell)
% Concatenate the images along the fourth (batch) dimension and the
% features along the first dimension.
X1 = cat(4,X1Cell{1:end});
X2 = cat(1, X2Cell{1:end});

% Extract label data from cell and concatenate.
T = cat(2,TCell{1:end});

% One-hot encode labels.
T = onehotencode(T,1);
end
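
The label handling in this function can be tried on its own with a couple of made-up labels. The sketch below assumes ten digit classes and uses placeholder label values. It shows that concatenating along the second dimension and encoding along the first yields a numClasses-by-batchSize matrix.

```matlab
% Two made-up labels drawn from the digit classes 0 through 9.
valueSet = 0:9;
TCell = {categorical(3,valueSet); categorical(7,valueSet)};

% Concatenate into a 1-by-2 categorical row vector.
T = cat(2,TCell{1:end});

% Encode along the first dimension: the result is 10-by-2, with a single 1
% in each column at the row of the corresponding class.
T = onehotencode(T,1);
disp(size(T))
```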

Evaluate Model Accuracy

The modelAccuracy function takes as input a dlnetwork object, a minibatchqueue object mbq of input data, the network classes, and the number of observations, and returns the accuracy as a percentage.

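function accuracy = modelAccuracy(net, mbq, classes, numObservations)
% This function computes the model accuracy (in percent) of a dlnetwork on
% the minibatchqueue 'mbq'.

totalCorrect = 0;

% Convert the class names to integer codes for decoding.
classes = int32(categorical(classes));

% Start from the first mini-batch in case the queue was read previously.
reset(mbq);

while hasdata(mbq)
    [dlX1, dlX2, Y] = next(mbq);

    % Predict class scores and convert them to a numeric array.
    dlYPred = extractdata(predict(net, dlX1, dlX2));

    % Decode predicted and true labels along the class dimension.
    YPred = onehotdecode(dlYPred,classes,1)';
    YReal = onehotdecode(Y,classes,1)';

    miniBatchCorrect = nnz(YPred == YReal);

    totalCorrect = totalCorrect + miniBatchCorrect;
end

accuracy = totalCorrect / numObservations * 100;
end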

Model Predictions Function

The modelPredictions function takes as input a dlnetwork object, a minibatchqueue object mbq of input data, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.

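function YPred = modelPredictions(net, mbq, classes)

YPred = [];

% Start from the first mini-batch in case the queue was read previously.
reset(mbq);

while hasdata(mbq)
    [dlX1, dlX2] = next(mbq);

    % Predict class scores and convert them to a numeric array.
    dlYPred = extractdata(predict(net, dlX1, dlX2));

    % Decode the highest-scoring class for each observation.
    currentYPred = onehotdecode(dlYPred,classes,1)';

    YPred = cat(1, YPred, currentYPred);
end
end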
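
The decoding step can be checked on a small, made-up score matrix. In this sketch the class names and scores are placeholders. Each column holds the scores for one observation, and onehotdecode returns the class with the highest score in each column.

```matlab
% Made-up scores: 3 classes (rows) by 2 observations (columns).
scores = [0.1 0.7
          0.8 0.2
          0.1 0.1];
classNames = ["a" "b" "c"];

% Decode along the first (class) dimension, then transpose to a column,
% as the function above does.
labels = onehotdecode(scores,classNames,1)';
disp(labels)
```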



Metric Function for Validation

The hComputeModelAccuracy metric function accepts as input the prediction scores, a dlnetwork object, a validation datastore, and the network classes. The function compares predicted labels to ground truth label data and returns the accuracy.

function accuracy = hComputeModelAccuracy(predictionScores, ~, dataStore, classes)
%% Computes model-level accuracy statistics
    % Load ground truth from the last (third) column of the datastore output.
    tmp = readall(dataStore);
    groundTruth = tmp(:,3);
    numGroundTruth = numel(groundTruth);

    % Arrange the scores as one row per observation.
    predictionScores = reshape(predictionScores, [numel(predictionScores)/numGroundTruth numGroundTruth])';
    % Compare predicted label with actual ground truth.
    predictionError = {};
    for idx=1:numGroundTruth
        [~, idy] = max(predictionScores(idx,:)); 
        yActual = classes(idy);
        predictionError{end+1} = (yActual == groundTruth{idx}); %#ok
    end
    % Average the per-observation correctness flags.
    predictionError = [predictionError{:}];
    accuracy = sum(predictionError)/numel(predictionError);
end
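
The comparison loop above can also be written without a loop. The following sketch shows the same idea on made-up data. The scores, class names, and ground-truth labels are placeholders, and the scores are already arranged with one row per observation.

```matlab
% Made-up scores: 3 observations (rows) by 2 classes (columns).
scores = [0.9 0.1
          0.2 0.8
          0.6 0.4];
classNames = categorical(["0";"1"]);
groundTruth = categorical(["0";"1";"1"]);

% Index of the highest score in each row, mapped to a class label.
[~,idx] = max(scores,[],2);
predicted = classNames(idx);

% Fraction of observations whose predicted label matches the ground truth.
accuracy = mean(predicted == groundTruth);
```

Here the third observation is misclassified, so the accuracy is two out of three.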
