Visualizzazione delle attivazioni della rete LSTM
Questo esempio mostra come analizzare e visualizzare le feature apprese dalle reti LSTM tramite l'estrazione delle attivazioni.
Caricare la rete preaddestrata. JapaneseVowelsNet
è una rete LSTM preaddestrata, addestrata sul set di dati delle vocali giapponesi, come descritto in [1] e [2]. La rete è stata addestrata sulle sequenze ordinate in base alla lunghezza della sequenza, con una dimensione del mini-batch di 27.
load JapaneseVowelsNet
Visualizzare l’architettura di rete.
net.Layers
ans = 5x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax 5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Caricare i dati di prova.
[XTest,YTest] = japaneseVowelsTestData;
Visualizzare la prima serie temporale in un grafico. Ogni riga corrisponde a una feature.
X = XTest{1}; figure plot(XTest{1}') xlabel("Time Step") title("Test Observation 1") numFeatures = size(XTest{1},1); legend("Feature " + string(1:numFeatures),'Location','northeastoutside')
Per ogni passo temporale delle sequenze, ottenere le attivazioni di output dal livello LSTM (livello 2) per quel passo temporale e aggiornare lo stato della rete.
sequenceLength = size(X,2); idxLayer = 2; outputSize = net.Layers(idxLayer).NumHiddenUnits; for i = 1:sequenceLength features(:,i) = activations(net,X(:,i),idxLayer); [net, YPred(i)] = classifyAndUpdateState(net,X(:,i)); end
Visualizzare le prime 10 unità nascoste utilizzando una mappa di calore.
figure heatmap(features(1:10,:)); xlabel("Time Step") ylabel("Hidden Unit") title("LSTM Activations")
La mappa di calore mostra l’intensità di attivazione di ciascuna unità nascosta e evidenzia come le attivazioni cambino nel tempo.
Riferimenti
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Vedi anche
trainNetwork
| trainingOptions
| lstmLayer
| bilstmLayer
| sequenceInputLayer
| activations