classifyAndUpdateState

Classify data using a trained recurrent neural network and update the network state

You can make predictions using a trained deep learning network on either a CPU or a GPU. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher. Specify the hardware using the 'ExecutionEnvironment' name-value pair argument.

Syntax

[updatedNet,YPred] = classifyAndUpdateState(recNet,sequences)
[updatedNet,YPred] = classifyAndUpdateState(___,Name,Value)
[updatedNet,YPred,scores] = classifyAndUpdateState(___)

Description

[updatedNet,YPred] = classifyAndUpdateState(recNet,sequences) classifies the data in sequences using the trained recurrent neural network recNet and updates the network state.

This function supports recurrent neural networks only. The input recNet must have at least one recurrent layer.

[updatedNet,YPred] = classifyAndUpdateState(___,Name,Value) uses any of the arguments in the previous syntaxes and additional options specified by one or more Name,Value pair arguments. For example, 'MiniBatchSize',27 classifies data using mini-batches of size 27.

[updatedNet,YPred,scores] = classifyAndUpdateState(___) uses any of the arguments in the previous syntaxes, also returns the predicted class scores in scores, and updates the network state.

Tip

When making predictions with sequences of different lengths, the mini-batch size can affect the amount of padding added to the input data, which can result in different predicted values. Try different values to see which works best with your network. To specify the mini-batch size and padding options, use the 'MiniBatchSize' and 'SequenceLength' options.
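To see why the mini-batch size changes the amount of padding, the following sketch counts padded time steps for a set of made-up sequence lengths. It does not call classifyAndUpdateState. It assumes 'SequenceLength' is 'longest' and that mini-batches are formed from consecutive observations in the order given; the lengths and variable names are illustrative only.

```matlab
% Hypothetical lengths (number of time steps) of six sequences,
% sorted by length. These values are made up for illustration.
seqLengths = [7 8 9 20 21 25];
miniBatchSizes = [2 3 6];
totalPadding = zeros(size(miniBatchSizes));

for k = 1:numel(miniBatchSizes)
    mbs = miniBatchSizes(k);
    % Assumption: each mini-batch is a consecutive group of observations.
    for first = 1:mbs:numel(seqLengths)
        last = min(first + mbs - 1, numel(seqLengths));
        batch = seqLengths(first:last);
        % With 'longest', each sequence is padded up to the length of the
        % longest sequence in its own mini-batch.
        totalPadding(k) = totalPadding(k) + sum(max(batch) - batch);
    end
    fprintf('MiniBatchSize %d: %d padded time steps\n', mbs, totalPadding(k));
end
```

For these lengths the padding is 16, 12, and 60 time steps, so the amount of padding depends on the mini-batch size, and not necessarily monotonically.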

Examples

Classify and Update Network State

Classify data using a recurrent neural network and update the network state.

To reproduce the results in this example, set rng to 'default'.

rng('default')

Load JapaneseVowelsNet, a pretrained long short-term memory (LSTM) network trained on the Japanese Vowels data set as described in [1] and [2]. This network was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet

View the network architecture.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Load the test data.

load JapaneseVowelsTest

Loop over the time steps in a sequence. Classify each time step and update the network state.

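X = XTest{94};                 % 12-by-s sequence: 12 features, s time steps
numTimeSteps = size(X,2);
labels = categorical.empty;    % predicted class at each time step
for i = 1:numTimeSteps
    v = X(:,i);                % feature vector at time step i
    % Classify this time step. The returned network carries the updated
    % state into the next iteration.
    [net,label,score] = classifyAndUpdateState(net,v);
    labels(i) = label;
end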
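The loop works because each call returns a network whose state reflects every time step seen so far. The following toolbox-free sketch illustrates the same idea with a hypothetical toy recurrent cell h_t = tanh(W*x_t + U*h_{t-1}), which stands in for the network state and is not the LSTM layer used by JapaneseVowelsNet. Carrying the state across pieces of a sequence reproduces the result of processing the whole sequence, while resetting the state does not.

```matlab
% Toy recurrent cell (hypothetical): h_t = tanh(W*x_t + U*h_{t-1}).
rng(0)
numFeatures = 12;
numHidden = 4;
numTimeSteps = 10;
W = randn(numHidden, numFeatures);        % input weights
U = 0.5 * randn(numHidden, numHidden);    % recurrent weights
step = @(h, x) tanh(W*x + U*h);           % one state update

X = randn(numFeatures, numTimeSteps);     % one made-up sequence
chunks = {X(:, 1:4), X(:, 5:end)};        % the same sequence, split in two

% Reference: process the full sequence in one pass from a zero state.
hFull = zeros(numHidden, 1);
for t = 1:numTimeSteps
    hFull = step(hFull, X(:, t));
end

% (a) Carry the state across chunks, like reusing the returned updatedNet.
hCarry = zeros(numHidden, 1);
for c = 1:numel(chunks)
    for t = 1:size(chunks{c}, 2)
        hCarry = step(hCarry, chunks{c}(:, t));
    end
end

% (b) Reset the state before each chunk, like reusing the original network.
for c = 1:numel(chunks)
    hReset = zeros(numHidden, 1);
    for t = 1:size(chunks{c}, 2)
        hReset = step(hReset, chunks{c}(:, t));
    end
end

fprintf('Carry state: max difference %.3g\n', max(abs(hCarry - hFull)));
fprintf('Reset state: max difference %.3g\n', max(abs(hReset - hFull)));
```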

Plot the predicted labels in a stair plot. The plot shows how the predictions change between time steps.

figure
stairs(labels, '-o')
xlim([1 numTimeSteps])
xlabel("Time Step")
ylabel("Predicted Class")
title("Classification Over Time Steps")

Compare the predictions with the true label. Plot a horizontal line showing the true label of the observation.

trueLabel = YTest(94)
trueLabel = categorical
     3 

hold on
line([1 numTimeSteps],[trueLabel trueLabel], ...
    'Color','red', ...
    'LineStyle','--')
legend(["Prediction" "True Label"])

Input Arguments

recNet: Trained recurrent neural network, specified as a SeriesNetwork object. You can get a trained network by importing a pretrained network or by training your own network using the trainNetwork function.

recNet must be a recurrent neural network, that is, it must have at least one recurrent layer (for example, an LSTM layer).

sequences: Sequence or time series data, specified as one of the following:

  • N-by-1 cell array of numeric arrays, where N is the number of observations

  • Numeric array representing a single sequence

  • Datastore

For cell array or numeric array input, the dimensions of the numeric arrays containing the sequences depend on the type of data.

  • Vector sequences: c-by-s matrices, where c is the number of features of the sequences and s is the sequence length.

  • 2-D image sequences: h-by-w-by-c-by-s arrays, where h, w, and c correspond to the height, width, and number of channels of the images, respectively, and s is the sequence length.

  • 3-D image sequences: h-by-w-by-d-by-c-by-s arrays, where h, w, d, and c correspond to the height, width, depth, and number of channels of the 3-D images, respectively, and s is the sequence length.

For datastore input, the datastore must return data as a cell array of sequences or as a table whose first column contains sequences. The dimensions of the sequence data must match the formats described above.
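As a sketch of the expected dimension ordering, the following code builds made-up inputs in each numeric format. The sizes (12 features, three observations of different lengths, 32-by-32 RGB frames, 16-by-16-by-8 volumes) are arbitrary assumptions; only the order of the dimensions follows the formats above. With a network whose sequence input layer expects 12 features, such as JapaneseVowelsNet, a cell array like vectorSeqs could be passed as sequences.

```matlab
% Vector sequences: N-by-1 cell array of c-by-s matrices (s may differ).
numFeatures = 12;
seqLengths = [20 15 24];          % made-up sequence lengths
N = numel(seqLengths);
vectorSeqs = cell(N, 1);
for n = 1:N
    vectorSeqs{n} = randn(numFeatures, seqLengths(n));
end

% A single vector sequence can also be given as a numeric c-by-s array.
singleSeq = vectorSeqs{1};

% 2-D image sequence: h-by-w-by-c-by-s (32-by-32 RGB, 10 frames).
imageSeq = rand(32, 32, 3, 10);

% 3-D image sequence: h-by-w-by-d-by-c-by-s (1 channel, 5 frames).
volumeSeq = rand(16, 16, 8, 1, 5);
```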

Name-Value Pair Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside quotes. You can specify several name and value pair arguments in any order as Name1,Value1,...,NameN,ValueN.

Example: [updatedNet, YPred] = classifyAndUpdateState(recNet,C,'MiniBatchSize',27) classifies data using mini-batches of size 27.

Size of mini-batches to use for prediction, specified as the comma-separated pair consisting of 'MiniBatchSize' and a positive integer. Larger mini-batch sizes require more memory but can lead to faster predictions.

When making predictions with sequences of different lengths, the mini-batch size can affect the amount of padding added to the input data, which can result in different predicted values. Try different values to see which works best with your network. To specify the mini-batch size and padding options, use the 'MiniBatchSize' and 'SequenceLength' options.

Example: 'MiniBatchSize',256

Performance optimization, specified as the comma-separated pair consisting of 'Acceleration' and one of the following:

  • 'auto' — Automatically apply a number of optimizations suitable for the input network and hardware resource.

  • 'none' — Disable all acceleration.

The default option is 'auto'.

Using the 'Acceleration' option 'auto' can offer performance benefits, but at the expense of an increased initial run time. Subsequent calls with compatible parameters are faster. Use performance optimization when you plan to call the function multiple times using new input data.

Example: 'Acceleration','auto'

Hardware resource, specified as the comma-separated pair consisting of 'ExecutionEnvironment' and one of the following:

  • 'auto' — Use a GPU if one is available; otherwise, use the CPU.

  • 'gpu' — Use the GPU. Using a GPU requires Parallel Computing Toolbox and a CUDA-enabled NVIDIA GPU with compute capability 3.0 or higher. If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.

  • 'cpu' — Use the CPU.

Example: 'ExecutionEnvironment','cpu'

Option to pad, truncate, or split input sequences, specified as the comma-separated pair consisting of 'SequenceLength' and one of the following:

  • 'longest' — Pad sequences in each mini-batch to have the same length as the longest sequence. This option does not discard any data, though padding can introduce noise to the network.

  • 'shortest' — Truncate sequences in each mini-batch to have the same length as the shortest sequence. This option ensures that no padding is added, at the cost of discarding data.

  • Positive integer — For each mini-batch, pad the sequences to the smallest multiple of the specified length that is greater than or equal to the longest sequence length in the mini-batch, and then split the sequences into smaller sequences of the specified length. If splitting occurs, then the software creates extra mini-batches. Use this option if the full sequences do not fit in memory. Alternatively, try reducing the number of sequences per mini-batch by setting the 'MiniBatchSize' option to a lower value.

If you specify the sequence length as a positive integer, then the software processes the smaller sequences in consecutive iterations. The network updates the network state between the split sequences.

The software pads and truncates the sequences on the right. To learn more about the effect of padding, truncating, and splitting the input sequences, see Sequence Padding, Truncation, and Splitting.
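The following toolbox-free sketch mimics the three 'SequenceLength' options on one made-up mini-batch of three single-feature sequences, padding on the right with a value of -1 (as in the 'SequencePaddingValue' example). It illustrates the rules described above rather than the software's internal implementation, and it assumes that for an integer option the padded length is the smallest multiple of that integer that is at least the longest sequence length.

```matlab
% Made-up mini-batch: three 1-by-s sequences of lengths 7, 12, and 18.
seqs = {1:7, 1:12, 1:18};
lens = cellfun(@(s) size(s, 2), seqs);
padValue = -1;
% Pad sequence s on the right with padValue up to length L.
padRight = @(s, L) [s, padValue * ones(size(s, 1), L - size(s, 2))];

% 'longest': pad every sequence to the longest length in the mini-batch.
longest = cellfun(@(s) padRight(s, max(lens)), seqs, 'UniformOutput', false);

% 'shortest': truncate every sequence to the shortest length (no padding).
shortest = cellfun(@(s) s(:, 1:min(lens)), seqs, 'UniformOutput', false);

% Positive integer (here 5): pad to a multiple of 5, then split into
% chunks of 5 time steps that are processed in consecutive iterations.
splitLen = 5;
paddedLen = splitLen * ceil(max(lens) / splitLen);   % 20 for these data
numChunks = paddedLen / splitLen;                    % 4 chunks
paddedSeqs = cellfun(@(s) padRight(s, paddedLen), seqs, 'UniformOutput', false);
chunks = cell(1, numChunks);
for k = 1:numChunks
    cols = (k - 1) * splitLen + (1:splitLen);
    chunks{k} = cellfun(@(s) s(:, cols), paddedSeqs, 'UniformOutput', false);
end

disp(chunks{4}{3})   % last chunk of the longest sequence
```

In this sketch, the last two chunks of the shortest sequence contain only padding values, which is one way padding can influence the updated network state.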

Example: 'SequenceLength','shortest'

Value by which to pad input sequences, specified as the comma-separated pair consisting of 'SequencePaddingValue' and a scalar. This option is valid only when 'SequenceLength' is 'longest' or a positive integer. Do not pad sequences with NaN, because doing so can propagate errors throughout the network.

Example: 'SequencePaddingValue',-1

Output Arguments

updatedNet: Updated network, returned as a SeriesNetwork object.

YPred: Predicted class labels, returned as a categorical vector or a cell array of categorical vectors. The format of YPred depends on the type of problem:

  • Sequence-to-label classification: N-by-1 categorical vector of labels, where N is the number of observations.

  • Sequence-to-sequence classification: N-by-1 cell array of categorical sequences of labels, where N is the number of observations. Each sequence has the same number of time steps as the corresponding input sequence. For sequence-to-sequence classification problems with one observation, sequences can be a matrix. In this case, YPred is a categorical sequence of labels.

scores: Predicted class scores, returned as a matrix or a cell array of matrices. The format of scores depends on the type of problem:

  • Sequence-to-label classification: N-by-K matrix, where N is the number of observations and K is the number of classes.

  • Sequence-to-sequence classification: N-by-1 cell array of matrices, where N is the number of observations. Each matrix has K rows, where K is the number of classes, and the same number of time steps as the corresponding input sequence. For sequence-to-sequence classification problems with one observation, sequences can be a matrix. In this case, scores is a matrix of predicted class scores.
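To relate the two outputs, the sketch below takes made-up scores and picks, for each observation or time step, the class with the highest score. It assumes that the predicted label is the class with the maximum score, as is typical for networks ending in a softmax layer. The class names and score values are hypothetical, and the sketch uses a cell array of character vectors where the function itself returns categorical labels.

```matlab
classNames = {'1'; '2'; '3'};   % hypothetical class names (K = 3)

% Sequence-to-label: N-by-K scores, one row per observation.
scores = [0.10 0.70 0.20;
          0.55 0.25 0.20];
[~, idx] = max(scores, [], 2);              % best class for each row
YPred = reshape(classNames(idx), [], 1);    % N-by-1 labels

% Sequence-to-sequence (one observation): K-by-s scores, one column
% per time step.
seqScores = [0.6 0.2 0.1 0.1;
             0.3 0.5 0.2 0.3;
             0.1 0.3 0.7 0.6];
[~, idxSeq] = max(seqScores, [], 1);        % best class for each column
YPredSeq = reshape(classNames(idxSeq), 1, []);   % 1-by-s label sequence
```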

Algorithms

All functions for deep learning training, prediction, and validation in Deep Learning Toolbox™ perform computations using single-precision, floating-point arithmetic. Functions for deep learning include trainNetwork, predict, classify, and activations. The software uses single-precision arithmetic when you train networks on both CPUs and GPUs.
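Because of single-precision arithmetic, results can differ slightly from the same computation carried out in double precision. The sketch below compares a small matrix-vector product, such as a fully connected layer might compute, in both precisions. The weights and input are made up, and the size of the difference depends on the data; here it is typically a small multiple of eps('single'), which is about 1.19e-7.

```matlab
rng(0)
W = randn(9, 100);     % made-up weights (for example, 9 output classes)
x = randn(100, 1);     % made-up input vector

yDouble = W * x;                    % double-precision result
ySingle = single(W) * single(x);    % single-precision result

% Relative difference between the two results.
relErr = max(abs(double(ySingle) - yDouble)) / max(abs(yDouble));
fprintf('class(ySingle) = %s, relative difference = %.2e\n', ...
    class(ySingle), relErr);
fprintf('eps(''single'') = %.2e, eps(''double'') = %.2e\n', ...
    eps('single'), eps('double'));
```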

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters, Vol. 20, No. 11–13, 1999, pp. 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Introduced in R2017b