Invalid training data for LSTM network.

tyler seudath on 2 Dec 2021
Hi all,
I am creating an LSTM network and I am getting the error 'Invalid training data. Predictors and responses must have the same number of observations.' I converted the input training data to cell arrays and the output to categorical, yet the error persists.
Here is a sample of the code:
DataParts = zeros(size(Train1_inputX1,1), size(Train1_inputX1,2),1,2); %(4050,400,1,2)
DataParts(:,:,:,1) = real(cell2mat(Train1_inputX1));
DataParts(:,:,:,2) = imag(cell2mat(Train1_inputX1)) ;
XTrain=num2cell(reshape(DataParts, [400,1,2,4050])); %Train data
DataParts1 = zeros(size(testX1_input,1), size(testX1_input,2),1, 2);
DataParts1(:,:,:,1) = real(cell2mat(testX1_input));
DataParts1(:,:,:,2) = imag(cell2mat(testX1_input)) ;
Ttrain=num2cell(reshape(DataParts1,[400,1,2,500])); %Test data
DataParts2 = zeros(size(ValX1_input,1), size(ValX1_input,2),1, 2);
DataParts2(:,:,:,1) = real(cell2mat(ValX1_input));
DataParts2(:,:,:,2) = imag(cell2mat(ValX1_input));
Vtrain =num2cell(reshape(DataParts2,[400,1,2,450])); %450 is the number of segments %400 is the number of samples
Valoutfinal= categorical(ValX1_output); %450 values
testoutfinal = categorical(testX1_output); %500 values
Trainoutfinal= categorical(Train1_outputX1);%4050 values
%% NETWORK ARCHITECTURE
inputSize = [400 1 2];
numHiddenUnits = 800;
numClasses = 4;
layers = [ ...
    sequenceInputLayer(inputSize,'Name','input')
    flattenLayer('Name','flatten')
    bilstmLayer(numHiddenUnits,'OutputMode','last','Name','lstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];
% Specify training options.
maxEpochs = 100;
miniBatchSize = 27;
options = trainingOptions('sgdm', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');
%% Train network
net = trainNetwork(Ttrain,Trainoutfinal,layers,options);
Any help is greatly appreciated.
Thanks a mil.
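
A possible cause, offered as a hedged sketch and not a confirmed diagnosis: the trainNetwork call pairs Ttrain (the 500 test segments) with Trainoutfinal (4050 labels). In addition, num2cell applied to a 4-D array returns a cell array of the same size holding one scalar per cell, not one cell per segment. The sketch below uses small random stand-in data to show one way to build an N-by-1 cell array of 400-by-1-by-2 arrays and to check the observation counts before training. The names numObs, numSamples, X, Y, XCell and seg are hypothetical and do not come from the code above.

```matlab
% Stand-in data (hypothetical): 6 complex segments of 400 samples each.
numObs     = 6;
numSamples = 400;
X = complex(randn(numObs,numSamples), randn(numObs,numSamples));
Y = categorical(randi(4,numObs,1));   % one label per segment

% Build an N-by-1 cell array with one 400-by-1-by-2 array per segment
% (channel 1 = real part, channel 2 = imaginary part).
XCell = cell(numObs,1);
for k = 1:numObs
    seg = X(k,:).';                   % 400-by-1 column (non-conjugate transpose)
    XCell{k} = cat(3, real(seg), imag(seg));
end

% Check the observation counts before calling trainNetwork.
assert(numel(XCell) == numel(Y), 'Predictors and responses differ in count.');
disp(size(XCell{1}))
```

With data arranged this way, the training call would pair the training predictors with the training labels, for example trainNetwork(XCell,Y,layers,options), and not the test predictors. Whether a 400-by-1-by-2 array with a single time step is the right input for sequenceInputLayer([400 1 2]) in R2021a is an assumption worth checking against the documentation.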

Answers (0)

Release

R2021a
