Classify error: requires 3 arguments
Hello all, I have trained an LSTM model to classify EMG signals (one-dimensional time series) and produce a class prediction. When I test the trained LSTM on a test signal, classify produces an error saying it requires 3 arguments. No matter how I change the shape of the test signal, nothing helps. predict also produces errors. Could you please help?
% Training code:
% LSTM-1D classification using raw EMG signal
%
% Data path
dataPath = '/home/ubuntu/Desktop/EMG data analysis/EMG signal Matlab'; % not named "path", which would shadow the built-in path function
% Parameters
numHiddenUnits = 120;
numClasses = 8;
numChannels = 1;
% Now prepare the training data and labels for LSTM training
% Assuming sorted_emg_data is your sorted array with data and labels
% Extract the data for training
XTrain = cellfun(@(c) c.signal, sorted_emg_data(:, 1), 'UniformOutput', false);
% Extract the labels for training
TTrain = sorted_emg_data(:, 2);
% Convert the labels to a categorical array
TTrain = categorical(TTrain);
% Now XTrain contains all the EMG signals and TTrain contains the corresponding labels
% Now train the LSTM model (numHiddenUnits, numClasses and numChannels are defined above)
% Now define your layers with the correct number of output classes
layers = [ ...
sequenceInputLayer(numChannels)
bilstmLayer(numHiddenUnits, 'OutputMode', 'last')
fullyConnectedLayer(numClasses) % Make sure this matches the number of unique classes in TTrain
softmaxLayer
classificationLayer];
% Define your training options (make sure MiniBatchSize is appropriate for your dataset)
options = trainingOptions('adam', ...
'MaxEpochs', 50, ...
'MiniBatchSize', 3, ... % Adjust based on your hardware capabilities
'InitialLearnRate', 0.01, ...
'GradientThreshold', 1, ...
'Verbose', 0, ...
'Plots', 'training-progress');
% Train the network
net = trainNetwork(XTrain, TTrain, layers, options);
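With sequenceInputLayer(numChannels) and numChannels = 1, each sequence passed to trainNetwork or classify is expected as a numChannels-by-numTimeSteps array, i.e. a 1-by-T row vector. The sketch below enforces that layout with reshape; the signals are random stand-ins (hypothetical, not the actual EMG data), and the same reshape applies to a single test signal before it is classified.

```matlab
% Hypothetical stand-in data: three EMG-like signals stored with mixed orientation.
XTrain = {randn(500, 1), randn(1, 480), randn(650, 1)};

% sequenceInputLayer(1) expects each sequence as numChannels-by-numTimeSteps,
% i.e. a 1-by-T row vector when there is a single channel.
XTrain = cellfun(@(x) reshape(x, 1, []), XTrain, 'UniformOutput', false);

% Every cell should now have exactly one row (one channel).
numRows = cellfun(@(x) size(x, 1), XTrain);
seqLengths = cellfun(@(x) size(x, 2), XTrain);

% A single test signal needs the same layout before calling classify.
xTest = randn(700, 1);          % hypothetical test signal (column vector)
xTest = reshape(xTest, 1, []);  % now 1-by-700
```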
% Testing code
%
% Loading the network
net = load ("lstm_trained_model.mat")
% Loading data
test = load ('emg_signal_3.mat')
net.layers
length (test)
length (signal) % signal directly loaded
% Classification
pred = classify(net, test);
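One possible cause worth checking (an assumption, not confirmed in this thread): when load is called with an output argument, it returns a struct whose fields are the saved variables, not the variables themselves. If that struct is passed to classify, MATLAB does not call the network's classify method and may instead call the Statistics and Machine Learning Toolbox function classify(sample, training, group), which needs three inputs. This self-contained sketch uses a hypothetical stand-in variable to show the struct and how to pull the variable out of it.

```matlab
% Hypothetical stand-in for the trained network; any saved variable behaves the same way.
net = magic(3);
matFile = [tempname '.mat'];   % temporary file used only for this demonstration
save(matFile, 'net');

S = load(matFile);             % S is a struct, not the saved variable
names = fieldnames(S);         % names of the variables stored in the file: {'net'}
loadedNet = S.(names{1});      % extract the variable itself
delete(matFile);               % remove the temporary file
```

Applied to the question, this would be S = load("lstm_trained_model.mat"); net = S.net; assuming the network was saved under the variable name net, and likewise extracting the signal field from the struct returned by load('emg_signal_3.mat') and reshaping it to 1-by-T before calling classify(net, X).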
Answers (1)
Cris LaPierre
on 2 Jan 2024
I cannot reproduce your error. I used this example to create a sample data set, trained on it using your code, and then tested it using the code in the PDFs. My conclusion is that there is nothing wrong with your code. Without more details, I don't know what more we can do to help.
Here are the results I obtained when running the model on test data using the code from your PDFs.

Category: Measurements and Feature Extraction