Main Content

Come iniziare a utilizzare Deep Network Designer

Questo esempio mostra come creare una rete neurale ricorrente semplice per la classificazione di sequenze di Deep Learning utilizzando Deep Network Designer.

Per addestrare una rete neurale profonda alla classificazione di dati sequenziali, si può utilizzare una rete LSTM. Una rete LSTM consente di immettere dati sequenziali in una rete ed eseguire previsioni basate sulle singole fasi temporali dei dati sequenziali.

Caricamento dei dati sequenziali

Caricare i dati di esempio da WaveformData. Per accedere a questi dati, aprire l'esempio come script live. Questi dati contengono forme d'onda di quattro classi: seno, quadrato, triangolo e dente di sega. Questo esempio addestra una rete neurale LSTM a riconoscere il tipo di forma d'onda data una serie di dati temporali. Ciascuna sequenza ha tre canali e varia in lunghezza.

load WaveformData 

Visualizzare alcune delle sequenze in un grafico.

numChannels = size(data{1},2);
classNames = categories(labels);

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{i},DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(i)))
end

Suddividere i dati in un set di addestramento contenente l'80% dei dati e in un set di convalida e in uno di test, contenenti ciascuno il 10% dei dati. Per suddividere i dati, utilizzare la funzione trainingPartitions. Per accedere a questa funzione, aprire l'esempio come script live.

numObservations = numel(data);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]);

XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XValidation = data(idxValidation);
TValidation = labels(idxValidation);

XTest = data(idxTest);
TTest = labels(idxTest);

Definizione dell’architettura di rete

Per costruire la rete, utilizzare l'applicazione Deep Network Designer.

deepNetworkDesigner

Per creare una rete sequenziale, nella sezione Sequence Networks (Reti sequenziali), fermarsi su Sequence to Label (Sequenza da etichettare) e fare clic su Open. In questo modo si apre una rete precostituita adatta per problemi di classificazione da sequenza a etichetta.

Deep Network Designer mostra la rete precostruita.

È possibile adattare facilmente questa rete sequenziale al set di dati Waveform.

Selezionare il livello di input della sequenza input e impostare InputSize (Dimensione dell'input) su 3, in modo che corrisponda al numero di canali.

Selezionare il livello completamente connesso fc e impostare OutputSize (Dimensione dell'output) su 4, in modo che corrisponda al numero di classi.

Per verificare che la rete sia pronta per l’addestramento, fare clic su Analyze (Analizza). Il Deep Learning Network Analyzer non riporta errori o avvisi, quindi la rete è pronta per l'addestramento. Per esportare la rete, fare clic su Export (Esporta). L'applicazione salva la rete nella variabile net_1.

Specificazione delle opzioni di addestramento

Specificare le opzioni di addestramento. La scelta tra le opzioni richiede un'analisi empirica.

options = trainingOptions("adam", ...
    MaxEpochs=500, ...
    InitialLearnRate=0.0005, ...
    GradientThreshold=1, ...
    ValidationData={XValidation,TValidation}, ...
    Shuffle = "every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Addestramento di reti neurali

Addestrare la rete neurale utilizzando la funzione trainnet. Poiché l'obiettivo è la classificazione, specificare la perdita di entropia incrociata.

net = trainnet(XTrain,TTrain,net_1,"crossentropy",options);

Test della rete neurale

Per testare la rete neurale, classificare i dati di test e calcolare la precisione della classificazione.

Fare previsioni utilizzando la funzione minibatchpredict e convertire i punteggi in etichette utilizzando la funzione scores2label.

scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);

Calcolare la precisione della classificazione. La precisione è la percentuale di etichette previste correttamente.

acc = mean(YTest == TTest)
acc = 0.8300

Visualizzare le previsioni in un grafico di confusione.

figure
confusionchart(TTest,YTest)

Vedi anche

Argomenti complementari