Main Content

La traduzione di questa pagina non è aggiornata. Fai clic qui per vedere l'ultima versione in inglese.

Classificazione di sequenze utilizzando il Deep Learning

Questo esempio mostra come classificare i dati sequenziali utilizzando una rete con memoria a breve e lungo termine (LSTM).

Per addestrare una rete neurale profonda alla classificazione di dati sequenziali, si può utilizzare una rete LSTM. Una rete LSTM consente di immettere dati sequenziali in una rete ed eseguire previsioni basate sulle singole fasi temporali dei dati sequenziali.

Questo esempio utilizza il set di dati delle vocali giapponesi descritto in [1] e [2]. Questo esempio addestra una rete LSTM a riconoscere i dati di una determinata serie temporale per oratori che rappresentano due vocali giapponesi pronunciate in successione. I dati di addestramento contengono dati di serie temporali per nove oratori. Ogni sequenza ha 12 feature e varia in lunghezza. L’insieme di dati contiene 270 osservazioni di addestramento e 370 osservazioni di test.

Caricamento dei dati sequenziali

Caricare i dati di addestramento delle vocali giapponesi. XTrain è un array di celle contenente 270 sequenze di dimensione 12 e di lunghezza variabile. Y è un vettore categoriale di etichette "1","2",...,"9", che corrispondono ai nove oratori. Le voci in XTrain sono matrici con 12 righe (una riga per ogni feature) e un numero variabile di colonne (una colonna per ciascuna fase temporale).

[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)
ans=5×1 cell array
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

Visualizzare la prima serie temporale in un grafico. Ogni riga corrisponde a una feature.

figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),Location="northeastoutside")

Figure contains an axes object. The axes object with title Training Observation 1, xlabel Time Step contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Preparazione dei dati per il riempimento

Per impostazione predefinita, durante l’addestramento, il software divide i dati di addestramento in mini-batch e riempie le sequenze in modo che abbiano la stessa lunghezza. Un riempimento eccessivo può avere un impatto negativo sulle prestazioni della rete.

Onde evitare che il processo di addestramento aggiunga un riempimento eccessivo, è possibile ordinare i dati di addestramento in base alla lunghezza della sequenza e scegliere una dimensione del mini-batch in modo che le sequenze in un mini-batch abbiano una lunghezza simile. La figura seguente mostra l'effetto del riempimento delle sequenze prima e dopo l'ordinamento dei dati.

Ottenere le lunghezze delle sequenze per ciascuna osservazione.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Ordinare i dati in base alla lunghezza della sequenza.

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Visualizzare le lunghezze delle sequenze ordinate in un grafico a barre.

figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

Figure contains an axes object. The axes object with title Sorted Data, xlabel Sequence, ylabel Length contains an object of type bar.

Scegliere 27 come dimensione del mini-batch per dividere i dati di addestramento uniformemente e ridurre la quantità di riempimento nei mini-batch. La figura seguente illustra il riempimento aggiunto alle sequenze.

miniBatchSize = 27;

Definizione dell’architettura di rete LSTM

Definire l’architettura di rete LSTM. Specificare che la dimensione di input sia in sequenze di dimensione 12 (la dimensione dei dati di input). Specificare un livello LSTM bidirezionale con 100 unità nascoste e generare l’ultimo elemento della sequenza. Specificare infine nove classi, includendo un livello completamente connesso di dimensione 9, seguito da un livello softmax e da un livello di classificazione.

Se si ha accesso a frequenze complete al momento della previsione, è possibile utilizzare un livello LSTM bidirezionale nella rete. Un livello LSTM bidirezionale apprende dalla sequenza completa ad ogni fase temporale. Se non si ha accesso all'intera sequenza al momento della previsione, ad esempio se si stanno prevedendo valori o una fase temporale alla volta, utilizzare invece un livello LSTM.

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

Specificare adesso le opzioni di addestramento. Specificare "adam" come solver, 1 per la soglia del gradiente e 50 come numero massimo di epoche. Per riempire i dati in modo che abbiano la stessa lunghezza delle sequenze più lunghe, specificare "longest" per la lunghezza della sequenza. Per garantire che i dati rimangano ordinati in base alla lunghezza della sequenza, specificare di non mescolare mai i dati.

Poiché i mini-batch sono piccoli e con sequenze brevi, l'addestramento è maggiormente adatto per la CPU. Impostare l’opzione ExecutionEnvironment su "cpu". Per eseguire l’addestramento su una GPU, se disponibile, impostare l’opzione ExecutionEnvironment su "auto" (questo è il valore predefinito).

options = trainingOptions("adam", ...
    ExecutionEnvironment="cpu", ...
    GradientThreshold=1, ...
    MaxEpochs=50, ...
    MiniBatchSize=miniBatchSize, ...
    SequenceLength="longest", ...
    Shuffle="never", ...
    Verbose=0, ...
    Plots="training-progress");

Addestramento della rete LSTM

Addestrare la rete LSTM con le opzioni di addestramento specificate utilizzando trainNetwork.

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (23-Mar-2023 10:39:10) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 9 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 9 objects of type patch, text, line.

Test della rete LSTM

Caricare l’insieme di test e classificare le sequenze in oratori.

Caricare i dati di test delle vocali giapponesi. XTest è un array di celle contenente 370 sequenze di dimensione 12 e di lunghezza variabile. YTest è un vettore categoriale di etichette "1","2",...,"9", che corrispondono ai nove oratori.

[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)
ans=3×1 cell array
    {12x19 double}
    {12x17 double}
    {12x19 double}

La rete LSTM net è stata addestrata utilizzando mini-batch di sequenze con lunghezza simile. Assicurarsi che i dati di test siano organizzati nello stesso modo. Ordinare i dati di test in base alla lunghezza della sequenza.

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

Classificare i dati di test. Per ridurre la quantità di riempimento introdotta dal processo di classificazione, specificare la stessa dimensione del mini-batch utilizzato per l'addestramento. Per applicare lo stesso riempimento dei dati di addestramento, specificare "longest" per la lunghezza della sequenza.

YPred = classify(net,XTest, ...
    MiniBatchSize=miniBatchSize, ...
    SequenceLength="longest");

Calcolare la precisione della classificazione delle previsioni.

acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9676

Riferimenti

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Vedi anche

| | | |

Argomenti complementari