
netTrained = trainnet(sequences,targets,net,lossFcn,options): cannot use this function when sequences contains complex numbers

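Question:
I am using the function netTrained = trainnet(sequences,targets,net,lossFcn,options).
How do I use this function when sequences contains complex numbers?
The function documentation notes that complex inputs are supported: "This argument supports complex-valued predictors and targets."
Code: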
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
numChannels = betalen;
layers = [
sequenceInputLayer(numChannels)
lstmLayer(128)
fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
MaxEpochs=200, ...
SequencePaddingDirection="left", ...
Shuffle="every-epoch", ...
Plots="training-progress", ...
Verbose=false);
net = trainnet(XTrain,TTrain,layers,"mse",options);
Error output:
Error using trainnet (line 46)
Execution failed during layer 'lstm'.
Error in HDL (line 66)
net = trainnet(XTrain,TTrain,layers,"mse",options);
Caused by:
Error using dlarray/lstm (line 105)
Invalid argument at position 1. Value must be real.

Answers (1)

Paras Gupta on 18 Jul 2024 at 4:52
Edited: Paras Gupta on 18 Jul 2024 at 4:56
Hi Alexander,
I understand that you are trying to use the "trainnet" function on complex-valued sequences and complex-valued targets.
You are correct in noting that the documentation indicates that the "trainnet" function can support complex-valued predictors and targets. However, the built-in loss functions provided by "trainnet" do not inherently support complex-valued targets. To address this, you will need to define a custom loss function that can handle complex values for targets.
Moreover, the "sequenceInputLayer" in your model should be configured to handle complex-valued inputs. This can be done by setting the "SplitComplexInputs" argument to true.
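To make the effect of "SplitComplexInputs" concrete, here is a minimal sketch of the equivalent manual split on a plain array. As I understand the "sequenceInputLayer" documentation, the layer outputs twice as many channels, with the real components first and the imaginary components after them, so the layers that follow (such as the LSTM) only ever see real data. The names X and XSplit and the sample values are illustrative assumptions, and the channel dimension is assumed to be the first one here.

```matlab
% Illustrative data: 2 complex channels (rows) by 3 time steps (columns).
X = [1+2i, 3+4i, 5+6i;
     7-1i, 8-2i, 9-3i];
numChannels = size(X, 1);

% Manual equivalent of the split: real parts first, then imaginary parts.
% The result has 2*numChannels real-valued channels.
XSplit = cat(1, real(X), imag(X));

disp(size(XSplit))    % 4-by-3
disp(isreal(XSplit))  % true
```

This is only meant to show the channel layout; inside the network the layer does the split for you.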
Below is an example with a custom loss function for complex-valued targets, passed to "trainnet" as a function handle:
% dummy data
numSamples = 100;
numTimesteps = 10;
numChannels = 2;
realPart = randn(numSamples, numTimesteps, numChannels);
imagPart = randn(numSamples, numTimesteps, numChannels);
dataTrain = realPart + 1i * imagPart;
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
% complex target
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
% real target
% TTrain = rand(numSamples, numChannels, numTimesteps-1);
layers = [
sequenceInputLayer(numChannels, SplitComplexInputs=true) % split each complex channel into real and imaginary channels
lstmLayer(128)
fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
MaxEpochs=200, ...
SequencePaddingDirection="left", ...
Shuffle="every-epoch", ...
Plots="training-progress", ...
Verbose=false);
% net = trainnet(XTrain, TTrain, layers, "mse", options);
% custom loss function passed as function handle
net = trainnet(XTrain, TTrain, layers, @complexLoss, options);
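function loss = complexLoss(Y, T)
% Mean squared magnitude of the complex difference between predictions and targets.
difference = Y - T;
% |Y - T|^2 = real^2 + imag^2, so errors in both parts count and the loss stays real-valued
squaredMagnitude = real(difference).^2 + imag(difference).^2;
loss = mean(squaredMagnitude, 'all');
end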
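A loss that squares only real(Y - T) ignores any error in the imaginary part, whereas the squared magnitude counts both. The following sketch compares the two on plain arrays. The values of Y and T are made up for illustration, and no toolbox functions are needed.

```matlab
% Made-up prediction and target that differ only in the imaginary part.
Y = [1+2i, 3-1i];
T = [1+0i, 3+1i];
difference = Y - T;   % [0+2i, 0-2i]

% Real part only: the imaginary error is invisible.
lossRealOnly = mean(real(difference).^2, 'all');

% Full squared magnitude |Y - T|^2.
lossMagnitude = mean(real(difference).^2 + imag(difference).^2, 'all');

fprintf('real only: %g, squared magnitude: %g\n', lossRealOnly, lossMagnitude);
% prints "real only: 0, squared magnitude: 4"
```

Either form returns a real scalar, which is what a training loss needs to be.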
You can refer to the following documentation links for more information on the code above:
Hope this helps with your work.

Release

R2024a
