Assemble Multiple-Output Network for Prediction

This example shows how to assemble a multiple-output network for prediction.

Instead of using the dlnetwork object for prediction, you can assemble the network into a DAGNetwork ready for prediction using the assembleNetwork function. This lets you use the predict function with other data types such as datastores.

Load Network and Class Names

Load the network and class names from the MAT file dlnetDigits.mat. The MAT file contains a dlnetwork object that predicts both the classification scores for the digit labels and the numeric angles of rotation for images of digits, along with the corresponding class names.

s = load("dlnetDigits.mat");
net = s.net;
classNames = s.classNames;

Assemble Network for Prediction

Extract the layer graph from the dlnetwork object using the layerGraph function.

lgraph = layerGraph(net);

The layer graph does not include output layers. Add a classification layer and a regression layer to the layer graph using the addLayers and connectLayers functions.

layers = classificationLayer(Classes=classNames,Name="coutput");
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,"softmax","coutput");

layers = regressionLayer(Name="routput");
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,"fc2","routput");
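
One way to confirm that both output layers are attached where intended is to filter the Connections table of the layer graph. The following is a minimal sketch that assumes lgraph from the steps above is in the workspace. The variable names conn and isOutputEdge are illustrative and not part of the original example.

```matlab
% Assumes lgraph already contains the "coutput" and "routput" layers added above.
conn = lgraph.Connections;   % table with Source and Destination columns

% Keep only the rows whose destination is one of the new output layers.
isOutputEdge = ismember(conn.Destination,["coutput" "routput"]);
disp(conn(isOutputEdge,:))
```

If the displayed rows show softmax feeding coutput and fc2 feeding routput, the graph matches the connections made above.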

View a plot of the network.

figure
plot(lgraph)

Figure: plot of the layer graph.

Assemble the network using the assembleNetwork function.

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [19x1 nnet.cnn.layer.Layer]
    Connections: [19x2 table]
     InputNames: {'in'}
    OutputNames: {'coutput'  'routput'}

Make Predictions on New Data

Load the test data.

[XTest,T1Test,T2Test] = digitTest4DArrayData;

To make predictions using the assembled network, use the predict function. To return categorical labels for the classification output, set the ReturnCategorical option to true.

[Y1Test,Y2Test] = predict(net,XTest,ReturnCategorical=true);
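
Because the assembled network supports datastores, the same predictions can also be requested from a datastore instead of a numeric array. The following is a sketch only. It assumes that net and XTest from the steps above are in the workspace and that predict accepts an arrayDatastore directly for this network. If your release does not, wrapping the datastore with transform or combine may be needed. The names dsTest, Y1ds, and Y2ds are illustrative.

```matlab
% Assumes net (the assembled DAGNetwork) and XTest from the steps above.
% Wrap the 4-D image array in a datastore that reads one image at a time
% along the fourth (observation) dimension.
dsTest = arrayDatastore(XTest,IterationDimension=4);

% Assumption: predict accepts this datastore directly for an image-input network.
[Y1ds,Y2ds] = predict(net,dsTest,ReturnCategorical=true);
```

A datastore is mainly useful when the data does not fit in memory. For an in-memory array such as XTest, passing the array directly is simpler.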

Evaluate the classification accuracy.

accuracy = mean(Y1Test==T1Test)
accuracy = 0.9870
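
The accuracy expression relies on the fact that comparing two categorical arrays yields a logical array, and that the mean of a logical array is the fraction of true values. A small self-contained illustration with made-up labels (not taken from the digits data) might look like this:

```matlab
% Hypothetical labels for four observations (not from the digits data).
classes = ["0" "1" "2"];
predicted = categorical(["0" "1" "2" "2"],classes);
actual    = categorical(["0" "1" "2" "1"],classes);

% Element-wise comparison gives a logical array; its mean is the
% fraction of matching labels.
isCorrect = predicted == actual;
toyAccuracy = mean(isCorrect);
```

Three of the four labels match, so toyAccuracy is 0.75.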

Evaluate the regression error by computing the root mean squared error (RMSE) of the predicted angles.

angleRMSE = sqrt(mean((Y2Test - T2Test).^2))
angleRMSE = single
    6.0091
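
To see how the RMSE expression works step by step, here is a small self-contained sketch with made-up angles (not taken from the digits data). All variable names are illustrative.

```matlab
% Hypothetical predicted and true angles in degrees (not from the digits data).
predictedAngles = [13 7 23 -13];
trueAngles      = [10 10 20 -10];

% Every error has magnitude 3, so the squared errors are all 9,
% their mean is 9, and the square root is 3.
errors = predictedAngles - trueAngles;
toyRMSE = sqrt(mean(errors.^2));
```

Because the RMSE is in the same units as the response, the value for the test data above can be read as a typical prediction error in degrees.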

View some of the images with their predictions. Display the predicted angles in red and the correct angles in green.

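% Pick nine test images at random.
idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    % Image height in pixels and half of it, used to center the lines.
    sz = size(I,1);
    offset = sz/2;

    % Predicted angle (red dashed line).
    thetaPred = Y2Test(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--")

    % True angle from the test data (green dashed line).
    thetaTest = T2Test(idx(i));
    plot(offset*[1-tand(thetaTest) 1+tand(thetaTest)],[sz 0],"g--")

    hold off
    % Title each image with its predicted class label.
    label = string(Y1Test(idx(i)));
    title("Label: " + label)
end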
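
The endpoint formula in the loop above draws a line through the center of the image, tilted from the vertical by the given angle. A short worked example with made-up values (a 28-pixel-tall image and a 45-degree angle, chosen only for illustration) shows the geometry:

```matlab
% Hypothetical values: a 28-pixel-tall image and a 45-degree angle.
sz = 28;
offset = sz/2;
theta = 45;

% Same endpoint formula as in the loop above.
x = offset*[1-tand(theta) 1+tand(theta)];   % x-coordinates at y = sz and y = 0
y = [sz 0];

% The midpoint of the segment is the image center.
midpoint = [mean(x) mean(y)];
```

Here the segment runs from (0,28) to (28,0) and its midpoint is (14,14), the center of the image. An angle of 0 gives a vertical line at x = offset.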

Figure: nine test images in a 3-by-3 grid, each titled with its predicted label and overlaid with two lines.
