Reinforcement Learning Grid World multi-figures
Reinforcement Learning
on 14 Feb 2021
Commented: Reinforcement Learning
on 16 Feb 2021
Hello,
I made my own version of Grid World with my own obstacles (see the code below).
My question is: how can I simulate the trained agent in the environment in multiple figures?
I am using:
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(agent,env)
This gives me only one variation. I tried using:
for i=1:3
    figure(i)
    plot(env)
    env.Model.Viewer.ShowTrace = true;
    env.Model.Viewer.clearTrace;
    sim(agent,env)
end
But it didn't work as planned.
Here is my code. For some reason, I am getting spikes in the reward plot even though the reward has already converged. I tried tuning parameters such as LearnRate, Epsilon and DiscountFactor, but this is the best result I get:
GitterWelt = createGridWorld(7,7);
GitterWelt.CurrentState = '[1,1]';
GitterWelt.ObstacleStates = ["[5,3]";"[5,4]";"[5,5]";"[4,5]";"[3,5]"];
GitterWelt.TerminalStates = '[6,6]';
updateStateTranstionForObstacles(GitterWelt)
nS = numel(GitterWelt.States);
nA = numel(GitterWelt.Actions);
GitterWelt.R = -1*ones(nS,nS,nA);
GitterWelt.R(:,state2idx(GitterWelt,GitterWelt.TerminalStates),:) = 10;
env = rlMDPEnv(GitterWelt);
Obs_Info = getObservationInfo(env);
Act_Info = getActionInfo(env);
qTable = rlTable(Obs_Info, Act_Info);
qRep = rlQValueRepresentation(qTable, Obs_Info, Act_Info);
%% All trivial until here
qRep.Options.LearnRate = 0.2; % Alpha: This was in the example 1, but it doesn't make sense
Ag_Opts = rlQAgentOptions;
Ag_Opts.DiscountFactor = 0.9; % Gamma
Ag_Opts.EpsilonGreedyExploration.Epsilon = 0.02;
agent = rlQAgent(qRep,Ag_Opts);
Train_Opts = rlTrainingOptions;
Train_Opts.MaxEpisodes = 1000;
Train_Opts.MaxStepsPerEpisode = 40;
Train_Opts.StopTrainingCriteria = "AverageReward";
Train_Opts.StopTrainingValue = 10;
Train_Opts.Verbose = 1;
Train_Opts.ScoreAveragingWindowLength = 30;
Train_Opts.Plots = "training-progress";
Train_Info = train(agent,env,Train_Opts);
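To see whether the spikes are isolated episodes or a drift in the average, the statistics returned by train can be plotted directly. This is a minimal sketch, assuming Train_Info is the structure returned by the train call above with its EpisodeReward and AverageReward fields; the dip counter and its threshold of 5 are hypothetical additions for illustration, not part of the toolbox.

```matlab
% Plot every episode's reward next to the windowed average from training.
% Assumes Train_Info is the statistics structure returned by train(...) above.
epReward  = Train_Info.EpisodeReward;   % reward of each individual episode
avgReward = Train_Info.AverageReward;   % average over ScoreAveragingWindowLength episodes
episodes  = 1:numel(epReward);

figure
plot(episodes, epReward, '.')
hold on
plot(episodes, avgReward, 'LineWidth', 1.5)
hold off
xlabel('Episode')
ylabel('Reward')
legend('Episode reward', 'Average reward', 'Location', 'southeast')

% Hypothetical diagnostic: count recent episodes that dip well below the average.
dipThreshold = 5;                                   % arbitrary margin, adjust as needed
lastEps = max(1, numel(epReward) - 99):numel(epReward);
nDips = nnz(epReward(lastEps) < avgReward(lastEps) - dipThreshold);
fprintf('Dips in the last %d episodes: %d\n', numel(lastEps), nDips);
```

If the dips are single episodes while the average line stays flat, that pattern is at least consistent with occasional off-policy actions rather than a policy that has stopped converging.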
Accepted Answer
Emmanouil Tzorakoleftherakis
on 16 Feb 2021
Hello,
I wouldn't worry about the spikes as long as the average reward has converged. They could be caused by the agent exploring.
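One way to test the exploration explanation is to run the trained agent a few times after training and compare the returns. This is a sketch, assuming agent and env from the question's code are in the workspace and that sim follows the agent's greedy policy without epsilon-greedy exploration; MaxSteps of 40 mirrors MaxStepsPerEpisode, and five runs is an arbitrary choice.

```matlab
% Run the trained agent several times and report each run's total reward.
% Assumes agent and env from the question's code, and that sim uses the
% agent's greedy policy (no epsilon-greedy exploration during simulation).
simOpts = rlSimulationOptions('MaxSteps', 40, 'NumSimulations', 5);
experiences = sim(env, agent, simOpts);   % struct array, one element per run

totalRewards = zeros(numel(experiences), 1);
for k = 1:numel(experiences)
    % Reward is logged as a timeseries; summing its samples gives the episode return
    totalRewards(k) = sum(experiences(k).Reward.Data);
end
disp(totalRewards.')
```

Under that assumption, any remaining spread between runs would come from the start state chosen when the environment resets rather than from random actions.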
For your plotting question, the plot function for the grid world environments has been set up with a listener callback so that it can be updated on the fly every time you call step. This means that you can only have one plot per grid world environment.
A quick workaround is to create separate environment objects for the same grid world you created and call plot for each one. For example:
function env = MyGridWorld
    GitterWelt = createGridWorld(7,7);
    GitterWelt.CurrentState = '[1,1]';
    GitterWelt.ObstacleStates = ["[5,3]";"[5,4]";"[5,5]";"[4,5]";"[3,5]"];
    GitterWelt.TerminalStates = '[6,6]';
    updateStateTranstionForObstacles(GitterWelt)
    nS = numel(GitterWelt.States);
    nA = numel(GitterWelt.Actions);
    GitterWelt.R = -1*ones(nS,nS,nA);
    GitterWelt.R(:,state2idx(GitterWelt,GitterWelt.TerminalStates),:) = 10;
    env = rlMDPEnv(GitterWelt);
end
and then
env1 = MyGridWorld;
env2 = MyGridWorld;
plot(env1)
plot(env2)
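Applied to the original question, the same idea gives one figure per simulation. This is a sketch, assuming MyGridWorld is saved as MyGridWorld.m (or defined as a local function at the end of the script) and that a trained agent is in the workspace; nRuns and the cell array envs are illustrative names, and the trace calls are the same ones used in the question.

```matlab
% Simulate the same trained agent in several figures, one environment each.
% Assumes MyGridWorld and a trained agent are available in the workspace.
nRuns = 3;                        % number of figures; arbitrary choice
envs = cell(1, nRuns);            % keep every environment so its plot stays alive
for i = 1:nRuns
    envs{i} = MyGridWorld;        % new environment object, so it gets its own viewer
    plot(envs{i})                 % opens the grid world plot for this environment
    envs{i}.Model.Viewer.ShowTrace = true;
    envs{i}.Model.Viewer.clearTrace;
    sim(envs{i}, agent);          % trajectory is drawn only in this environment's plot
end
```

Keeping the environments in a cell array, instead of overwriting a single env variable, means each plot stays tied to its own environment object after the loop finishes.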
Hope that helps.