Come iniziare a utilizzare Deep Network Designer
Questo esempio mostra come creare una rete neurale ricorrente semplice per la classificazione di sequenze di Deep Learning utilizzando Deep Network Designer.
Per addestrare una rete neurale profonda alla classificazione di dati sequenziali, si può utilizzare una rete LSTM. Una rete LSTM consente di immettere dati sequenziali in una rete ed eseguire previsioni basate sulle singole fasi temporali dei dati sequenziali.
Caricamento dei dati sequenziali
Caricare i dati di esempio da WaveformData
. Per accedere a questi dati, aprire l'esempio come script live. Questi dati contengono forme d'onda di quattro classi: seno, quadrato, triangolo e dente di sega. Questo esempio addestra una rete neurale LSTM a riconoscere il tipo di forma d'onda data una serie di dati temporali. Ciascuna sequenza ha tre canali e varia in lunghezza.
load WaveformData
Visualizzare alcune delle sequenze in un grafico.
numChannels = size(data{1},2); classNames = categories(labels); figure tiledlayout(2,2) for i = 1:4 nexttile stackedplot(data{i},DisplayLabels="Channel "+string(1:numChannels)) xlabel("Time Step") title("Class: " + string(labels(i))) end
Suddividere i dati in un set di addestramento contenente l'80% dei dati e in un set di convalida e in uno di test, contenenti ciascuno il 10% dei dati. Per suddividere i dati, utilizzare la funzione trainingPartitions
. Per accedere a questa funzione, aprire l'esempio come script live.
numObservations = numel(data); [idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]); XTrain = data(idxTrain); TTrain = labels(idxTrain); XValidation = data(idxValidation); TValidation = labels(idxValidation); XTest = data(idxTest); TTest = labels(idxTest);
Definizione dell’architettura di rete
Per costruire la rete, utilizzare l'applicazione Deep Network Designer.
deepNetworkDesigner
Per creare una rete sequenziale, nella sezione Sequence Networks (Reti sequenziali), fermarsi su Sequence to Label (Sequenza da etichettare) e fare clic su Open. In questo modo si apre una rete precostituita adatta per problemi di classificazione da sequenza a etichetta.
Deep Network Designer mostra la rete precostruita.
È possibile adattare facilmente questa rete sequenziale al set di dati Waveform.
Selezionare il livello di input della sequenza input
e impostare InputSize (Dimensione dell'input) su 3, in modo che corrisponda al numero di canali.
Selezionare il livello completamente connesso fc
e impostare OutputSize (Dimensione dell'output) su 4, in modo che corrisponda al numero di classi.
Per verificare che la rete sia pronta per l’addestramento, fare clic su Analyze (Analizza). Il Deep Learning Network Analyzer non riporta errori o avvisi, quindi la rete è pronta per l'addestramento. Per esportare la rete, fare clic su Export (Esporta). L'applicazione salva la rete nella variabile net_1
.
Specificazione delle opzioni di addestramento
Specificare le opzioni di addestramento. La scelta tra le opzioni richiede un'analisi empirica.
options = trainingOptions("adam", ... MaxEpochs=500, ... InitialLearnRate=0.0005, ... GradientThreshold=1, ... ValidationData={XValidation,TValidation}, ... Shuffle = "every-epoch", ... Plots="training-progress", ... Metrics="accuracy", ... Verbose=false);
Addestramento di reti neurali
Addestrare la rete neurale utilizzando la funzione trainnet
. Poiché l'obiettivo è la classificazione, specificare la perdita di entropia incrociata.
net = trainnet(XTrain,TTrain,net_1,"crossentropy",options);
Test della rete neurale
Per testare la rete neurale, classificare i dati di test e calcolare la precisione della classificazione.
Fare previsioni utilizzando la funzione minibatchpredict
e convertire i punteggi in etichette utilizzando la funzione scores2label
.
scores = minibatchpredict(net,XTest); YTest = scores2label(scores,classNames);
Calcolare la precisione della classificazione. La precisione è la percentuale di etichette previste correttamente.
acc = mean(YTest == TTest)
acc = 0.8300
Visualizzare le previsioni in un grafico di confusione.
figure confusionchart(TTest,YTest)