Main Content

La traduzione di questa pagina non è aggiornata. Fai clic qui per vedere l'ultima versione in inglese.

Visualizzazione delle attivazioni della rete LSTM

Questo esempio mostra come analizzare e visualizzare le feature apprese dalle reti LSTM tramite l'estrazione delle attivazioni.

Caricare la rete preaddestrata. JapaneseVowelsNet è una rete LSTM preaddestrata, addestrata sul set di dati delle vocali giapponesi, come descritto in [1] e [2]. La rete è stata addestrata sulle sequenze ordinate in base alla lunghezza della sequenza, con una dimensione del mini-batch di 27.

load JapaneseVowelsNet

Visualizzare l’architettura di rete.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Caricare i dati di prova.

[XTest,YTest] = japaneseVowelsTestData;

Visualizzare la prima serie temporale in un grafico. Ogni riga corrisponde a una feature.

X = XTest{1};

figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes object. The axes object with title Test Observation 1, xlabel Time Step contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Per ogni passo temporale delle sequenze, ottenere le attivazioni di output dal livello LSTM (livello 2) per quel passo temporale e aggiornare lo stato della rete.

sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    features(:,i) = activations(net,X(:,i),idxLayer);
    [net, YPred(i)] = classifyAndUpdateState(net,X(:,i));
end

Visualizzare le prime 10 unità nascoste utilizzando una mappa di calore.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure contains an object of type heatmap. The chart of type heatmap has title LSTM Activations.

La mappa di calore mostra l’intensità di attivazione di ciascuna unità nascosta e evidenzia come le attivazioni cambino nel tempo.

Riferimenti

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Vedi anche

| | | | |

Argomenti complementari