Visualize Activations of LSTM Network
This example shows how to investigate and visualize the features learned by LSTM networks by extracting the activations.
Load pretrained network. JapaneseVowelsNet is a pretrained LSTM network trained on the Japanese Vowels dataset as described in [1] and [2]. It was trained on the sequences sorted by sequence length with a mini-batch size of 27.
load JapaneseVowelsNet
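Sorting by sequence length keeps sequences of similar length in the same mini-batch, which reduces the amount of padding. The following is a minimal sketch of that preprocessing step on synthetic data, not the code used to train JapaneseVowelsNet. The variable names (XDemo, XSorted) and the random sequence lengths are assumptions made for illustration.

```matlab
% Synthetic stand-in data: 6 sequences, each 12 features by a random
% number of time steps (lengths chosen arbitrarily for illustration).
rng(0)
numObservations = 6;
XDemo = cell(numObservations,1);
for n = 1:numObservations
    XDemo{n} = rand(12,randi([7 29]));
end

% The sequence length is the number of columns (time steps) of each observation.
sequenceLengths = cellfun(@(x) size(x,2),XDemo);

% Sort so that neighboring sequences, and therefore mini-batches, have similar lengths.
[sortedLengths,idx] = sort(sequenceLengths);
XSorted = XDemo(idx);
```

After sorting, each consecutive group of observations in XSorted forms a mini-batch whose sequences need little or no padding.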
View the network architecture.
net.Layers
ans = 
  4x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
     2   'lstm'            LSTM              LSTM with 100 hidden units
     3   'fc'              Fully Connected   9 fully connected layer
     4   'softmax'         Softmax           softmax
Load the test data.
load JapaneseVowelsTestData
Visualize the first time series in a plot. Each line corresponds to a feature.
X = XTest{1};
figure
plot(X')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(X,1);
legend("Feature " + string(1:numFeatures),'Location',"northeastoutside")
For each time step of the sequence, get the activations output by the LSTM layer (layer 2) for that time step and update the network state.
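sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

% Preallocate one row of activations per time step.
features = zeros(sequenceLength,outputSize);
for i = 1:sequenceLength
    % Pass only time step i; the updated state carries the history forward.
    [features(i,:),state] = predict(net,X(:,i)',Outputs="lstm");
    net.State = state;
end

% Transpose so that rows are hidden units and columns are time steps.
features = features';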
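To see why the state has to be carried from one time step to the next, the following sketch steps a hand-written LSTM cell through a short sequence in plain MATLAB. It is an illustration under assumptions, not the toolbox implementation: the weights are random, and the gate ordering and all variable names (XDemo, featuresDemo, iGate, and so on) are hypothetical.

```matlab
rng(0)
numFeatures = 12;      % matches the 12-dimensional input of this example
numHiddenUnits = 4;    % kept small for illustration (the pretrained network has 100)
numTimeSteps = 5;
XDemo = randn(numFeatures,numTimeSteps);

% Random weights for four gates stacked row-wise. The order used here
% (input, forget, candidate, output) is an assumption of this sketch.
W = 0.1*randn(4*numHiddenUnits,numFeatures);     % input weights
R = 0.1*randn(4*numHiddenUnits,numHiddenUnits);  % recurrent weights
b = zeros(4*numHiddenUnits,1);                   % bias

sigmoid = @(z) 1./(1 + exp(-z));

h = zeros(numHiddenUnits,1);   % hidden state, carried across time steps
c = zeros(numHiddenUnits,1);   % cell state, carried across time steps
featuresDemo = zeros(numHiddenUnits,numTimeSteps);

for t = 1:numTimeSteps
    % Gate pre-activations depend on the current input and the previous hidden state.
    z = W*XDemo(:,t) + R*h + b;
    iGate = sigmoid(z(1:numHiddenUnits));
    fGate = sigmoid(z(numHiddenUnits+1:2*numHiddenUnits));
    gCand = tanh(z(2*numHiddenUnits+1:3*numHiddenUnits));
    oGate = sigmoid(z(3*numHiddenUnits+1:end));

    % Update the cell state and hidden state; h is the layer activation at step t.
    c = fGate.*c + iGate.*gCand;
    h = oGate.*tanh(c);
    featuresDemo(:,t) = h;
end
```

Because h and c persist across loop iterations, each column of featuresDemo depends on all earlier time steps. Assigning the returned state to net.State plays the same role in the loop above.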
Visualize the first 10 hidden units using a heatmap.
figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")
The heatmap shows how strongly each hidden unit activates and highlights how the activations change over time.
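Instead of taking the first 10 hidden units, one option is to rank the units by how much their activation changes over the sequence and plot the most variable ones. This is a sketch on a synthetic matrix, not part of the original example. The names featuresDemo and numToShow, and the choice of range (maximum minus minimum) as the ranking criterion, are assumptions.

```matlab
rng(0)
numHiddenUnits = 100;
numTimeSteps = 20;

% Synthetic stand-in with values in (-1,1), laid out as hidden units by time steps.
featuresDemo = tanh(randn(numHiddenUnits,numTimeSteps));

% Range of each unit's activation across time (dimension 2).
activationRange = max(featuresDemo,[],2) - min(featuresDemo,[],2);

% Indices of the units whose activations vary the most.
numToShow = 10;
[~,idxSorted] = sort(activationRange,"descend");
idxTop = idxSorted(1:numToShow);

% Plot the selected units and label each row with its hidden unit index.
figure
hm = heatmap(featuresDemo(idxTop,:));
hm.YDisplayLabels = string(idxTop);
xlabel("Time Step")
ylabel("Hidden Unit")
title("Most Variable Activations (Synthetic Data)")
```

To apply the same ranking to the real activations, replace featuresDemo with the features matrix computed above.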
References
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
See Also
trainnet | trainingOptions | dlnetwork | predict | forward | lstmLayer | bilstmLayer | sequenceInputLayer