Initial State Dynamical System LSTM Network

Michael Hesse on 18 Nov 2020
Commented: Michael Hesse on 19 Nov 2020
%% - cleanup
clear;
close all;
clc;
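%% - data
t = linspace(0, 5, 1000);                             % time grid: 1000 samples on [0, 5]
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];  % x1' = x2, x2' = 10*sin(x1) - x2
x0 = [pi/2, 0]';                                      % initial state
[~, x] = ode45(odefcn, t, x0);                        % ode45 returns one row per time step
x = x';                                               % transpose: one column per time step
X = x(:, 1:end-1);                                    % inputs: states at steps 1..N-1
Y = x(:, 2:end);                                      % targets: states at steps 2..N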
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures);
lstmLayer(numHiddenUnits);
fullyConnectedLayer(numResponses);
regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
%% - prediction
net = resetState(net);
xpred = x0;
for i = 1 : length(t)-1
[net, xpred(:, i+1)] = predictAndUpdateState(net, xpred(:, i));
end
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
This is example code in which I want to predict the trajectory of a pendulum with an LSTM neural network. How can I provide the initial state x0 to the network? If you look at the figure, the second predicted state jumps directly from x0 to [0, 0]'. Why does this happen?
1 Comment
Michael Hesse on 19 Nov 2020
Here is a possible workaround. Instead of learning the next state, the network can learn the difference between the current state and the next state.
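As a rough sketch, the difference formulation can be written as follows. The notation is my own and not from the post: x_k stands for the state at time step t_k (so x_1 is the initial state x0), N is the number of time steps, and \widehat{\Delta x}_k is the network output at step k.

```latex
\begin{aligned}
\text{training targets:}\quad & y_k = x_{k+1} - x_k, \qquad k = 1, \dots, N-1, \\
\text{rollout:}\quad & \hat{x}_1 = x_1, \qquad \hat{x}_{k+1} = \hat{x}_k + \widehat{\Delta x}_k .
\end{aligned}
```

These two lines correspond to `Y = x(:, 2:end) - x(:, 1:end-1)` and `xpred(:, i+1) = xpred(:, i) + dxpred` in the script. Note that a network output of zero leaves the predicted state unchanged in this formulation.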
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';
X = x(:, 1:end-1);
Y = x(:, 2:end) - x(:, 1:end-1); % targets: state increments instead of next states
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures, 'Normalization', 'zscore');
lstmLayer(numHiddenUnits);
fullyConnectedLayer(numResponses);
regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
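%% - prediction
net = resetState(net); % start the rollout from the network's initial state, as in the first script
xpred = x0;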
for i = 1 : length(t)-1
[net, dxpred] = predictAndUpdateState(net, xpred(:, i));
xpred(:, i+1) = xpred(:, i) + dxpred;
end
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
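To compare the two variants with more than a visual check of the plot, one option is a per-state RMSE between the ODE solution and the closed-loop prediction. This is only a minimal sketch: `rolloutRmse` is a hypothetical helper that is not part of the original scripts, and the small matrices are stand-in data so the snippet runs on its own.

```matlab
% Hypothetical helper (not from the original post): per-state RMSE between a
% reference trajectory and a predicted one. Both inputs are
% numStates-by-numSteps matrices, laid out like x and xpred in the scripts.
rolloutRmse = @(xTrue, xHat) sqrt(mean((xTrue - xHat).^2, 2));

% Stand-in data: state 1 matches exactly, state 2 is off by 1 at every step.
xTrue = [0 1 2; 0 0 0];
xHat  = [0 1 2; 1 1 1];

err = rolloutRmse(xTrue, xHat);  % 2-by-1 vector, one RMSE value per state
disp(err);
```

With either script above, the call would be `err = rolloutRmse(x, xpred);` after the prediction loop.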

Answers (0)

Release

R2020b
