Compare Agents on the Discrete Cart-Pole Environment
This example shows how to create and train frequently used default agents on a discrete action space cart-pole environment. This environment represents a pole attached to an unactuated joint on a cart, which moves along a frictionless track. The agent can apply a force on the cart and its training goal is to balance the pole upright using minimal control effort. The example plots performance metrics such as the total training time and the total reward for each trained agent. The results that the agents obtain in this environment, with the selected initial conditions and random number generator seed, do not necessarily imply that specific agents are better than others. Also, note that the training times depend on the computer and operating system you use to run the example, and on other processes running in the background. Your training times might differ substantially from the training times shown in the example.
Fix Random Number Stream for Reproducibility
The example code might involve computation of random numbers at various stages. Fixing the random number stream at the beginning of various sections in the example code preserves the random number sequence in the section every time you run it, and increases the likelihood of reproducing the results. For more information, see Results Reproducibility.
Fix the random number stream with seed zero and the Mersenne Twister random number algorithm. For more information on controlling the seed used for random number generation, see rng.
previousRngState = rng(0,"twister")
previousRngState = struct with fields:
Type: 'twister'
Seed: 0
State: [625×1 uint32]
The output previousRngState is a structure that contains information about the previous state of the stream. You will restore the state at the end of the example.
Discrete Action Space Cart-Pole MATLAB Environment
The reinforcement learning environment for this example is a pole attached to an unactuated joint on a cart, which moves along a frictionless track. The agent can apply a force on the cart and its training goal is to balance the pole upright using minimal control effort.
For this environment:
The upright pole position is zero radians, and the downward hanging position is pi radians. The pole starts upright, with an initial angle between –0.05 and 0.05 radians.
The force action signal from the agent to the environment is either –10 or 10 N.
The observations from the environment are the position and velocity of the cart, the pole angle, and the pole angle derivative.
The episode terminates if the pole is more than 12 degrees from vertical or if the cart moves more than 2.4 m from the original position.
A reward of +1 is provided for every time step that the pole remains balanced upright. A penalty of –5 is applied when the pole falls.
For more information on this model, see Load Predefined Control System Environments.
Create Environment Object
Create a predefined environment object for the system.
env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4×1 double]
The object has a discrete action space where the agent can apply one of two possible force values to the cart, –10 or 10 N.
You can visualize the cart-pole system by using the plot function during training or simulation.
plot(env)
Get the observation and action specification objects.
obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"
actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0×0 string]
      Dimension: [1 1]
       DataType: "double"
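Before training any agent, you can interact with the environment manually to see the observation, reward, and termination signals described above. The following sketch uses the reset and step environment functions; applying a constant +10 N force is chosen only for illustration.

```matlab
% Reset the environment and apply a constant force for one time step.
% reset returns the initial [4x1] state; step advances the dynamics.
obs0 = reset(env);                    % x, dx, theta, dtheta
[obs1,reward,isDone] = step(env,10);  % apply +10 N for one sample time Ts
```

Here reward is +1 while the pole stays within 12 degrees of vertical, and isDone becomes true when a termination condition is met.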
Configure Training and Simulation Options for All Agents
Set up an evaluator object to evaluate the agent 10 times without exploration every 100 training episodes.
evl = rlEvaluator(NumEpisodes=10,EvaluationFrequency=100);
Create a training options object. For this example, use the following options:
Run the training for a maximum of 5000 episodes, with each episode lasting a maximum of 500 time steps.
Stop training when the average reward in the evaluation episodes is greater than 500. At this point, the agent can control the position of the pole using minimal control effort.
To gain better insight into the agent's behavior during training, plot the training progress (default option). To achieve faster training times, set the Plots option to "none".
trainOpts = rlTrainingOptions(...
    MaxEpisodes=5000, ...
    MaxStepsPerEpisode=500, ...
    StopTrainingCriteria="EvaluationStatistic",...
    StopTrainingValue=500);
For more information on training options, see rlTrainingOptions.
To simulate the trained agent, create a simulation options object and configure it to simulate for 500 steps.
simOptions = rlSimulationOptions(MaxSteps=500);
For more information on simulation options, see rlSimulationOptions.
Create, Train, and Simulate a DQN Agent
The constructor functions initialize the agent networks randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.
rng(0,"twister")
Create a default rlDQNAgent object using the environment specification objects.
dqnAgent = rlDQNAgent(obsInfo,actInfo);
Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.
dqnAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
dqnAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
Use a larger experience buffer to store more experiences, thereby decreasing the likelihood of catastrophic forgetting.
dqnAgent.AgentOptions.ExperienceBufferLength = 1e6;
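To see what the default agent actually contains, you can extract its critic and inspect the underlying network. This sketch uses the getCritic and getModel functions; the summary output shown depends on your toolbox version.

```matlab
% Extract the critic from the default DQN agent and inspect its network.
critic = getCritic(dqnAgent);
criticNet = getModel(critic);   % dlnetwork object with the default layers
summary(criticNet)              % overview of layers and learnable parameters
```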
Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Discrete");
    % Train the agent. Record the training time.
    tic
    dqnTngRes = train(dqnAgent,env,trainOpts,Evaluator=evl);
    dqnTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    dqnTngEps = dqnTngRes.EpisodeIndex(end);
    dqnTngSteps = sum(dqnTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("dcpBchDQNAgent.mat", ...
    %     "dqnAgent","dqnTngEps","dqnTngSteps","dqnTngTime")
else
    % Load the pretrained agent and results for the example.
    load("dcpBchDQNAgent.mat", ...
        "dqnAgent","dqnTngEps","dqnTngSteps","dqnTngTime")
end
For the DQN Agent, the training converges to a solution after 3000 episodes. You can check the trained agent within the cart-pole environment.
Ensure reproducibility of the simulation by fixing the seed used for random number generation.
rng(0,"twister")
Visualize the environment.
plot(env)
Configure the agent to use a greedy policy (no exploration) in simulation.
dqnAgent.UseExplorationPolicy = false;
Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.
experience = sim(env,dqnAgent,simOptions);
dqnTotalRwd = sum(experience.Reward)
dqnTotalRwd = 500
The trained DQN agent stabilizes the pole on the cart.
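Beyond the total reward, the experience output of sim also contains the logged observations, so you can check the pole angle over the whole episode. The observation field name below is an assumption based on the "CartPole States" specification name; adjust it to match the field actually present in experience.Observation.

```matlab
% Extract the logged observations; the third state is the pole angle theta.
% CartPoleStates is assumed from obsInfo.Name ("CartPole States").
obsLog = squeeze(experience.Observation.CartPoleStates.Data);  % 4-by-N
plot(experience.Observation.CartPoleStates.Time, obsLog(3,:))
xlabel("Time (s)")
ylabel("Pole angle (rad)")
```

For a successfully trained agent, the plotted angle should stay well within the 0.2094 rad termination threshold.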
Create, Train, and Simulate a PG Agent
The constructor functions initialize the agent networks randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.
rng(0,"twister")
Create a default rlPGAgent object using the environment specification objects.
pgAgent = rlPGAgent(obsInfo,actInfo);
Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.
pgAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
pgAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
pgAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
pgAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
Set the entropy loss weight to increase exploration.
pgAgent.AgentOptions.EntropyLossWeight = 0.005;
Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Discrete");
    % Train the agent. Record the training time.
    tic
    pgTngRes = train(pgAgent,env,trainOpts,Evaluator=evl);
    pgTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    pgTngEps = pgTngRes.EpisodeIndex(end);
    pgTngSteps = sum(pgTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("dcpBchPGAgent.mat", ...
    %     "pgAgent","pgTngEps","pgTngSteps","pgTngTime")
else
    % Load the pretrained agent and results for the example.
    load("dcpBchPGAgent.mat", ...
        "pgAgent","pgTngEps","pgTngSteps","pgTngTime")
end
For the PG Agent, the training converges to a solution after 100 episodes. You can check the trained agent within the cart-pole environment.
Ensure reproducibility of the simulation by fixing the seed used for random number generation.
rng(0,"twister")
Visualize the environment.
plot(env)
Configure the agent to use a greedy policy (no exploration) in simulation.
pgAgent.UseExplorationPolicy = false;
Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.
experience = sim(env,pgAgent,simOptions);
pgTotalRwd = sum(experience.Reward)
pgTotalRwd = 500
The trained PG agent stabilizes the pole on the cart.
Create, Train, and Simulate an AC Agent
The constructor functions initialize the agent networks randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.
rng(0,"twister")
Create a default rlACAgent object using the environment specification objects.
acAgent = rlACAgent(obsInfo,actInfo);
Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.
acAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
acAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
acAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
acAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
Set the entropy loss weight to increase exploration.
acAgent.AgentOptions.EntropyLossWeight = 0.005;
Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Discrete");
    % Train the agent. Record the training time.
    tic
    acTngRes = train(acAgent,env,trainOpts,Evaluator=evl);
    acTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    acTngEps = acTngRes.EpisodeIndex(end);
    acTngSteps = sum(acTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("dcpBchACAgent.mat", ...
    %     "acAgent","acTngEps","acTngSteps","acTngTime")
else
    % Load the pretrained agent and results for the example.
    load("dcpBchACAgent.mat", ...
        "acAgent","acTngEps","acTngSteps","acTngTime")
end
For the AC agent, the training converges to a solution after 100 episodes. You can check the trained agent within the cart-pole environment.
Ensure reproducibility of the simulation by fixing the seed used for random number generation.
rng(0,"twister")
Visualize the environment.
plot(env)
Configure the agent to use a greedy policy (no exploration) in simulation.
acAgent.UseExplorationPolicy = false;
Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.
experience = sim(env,acAgent,simOptions);
acTotalRwd = sum(experience.Reward)
acTotalRwd = 500
The trained AC agent does not stabilize the pole on the cart.
Create, Train, and Simulate a PPO Agent
The constructor functions initialize the agent networks randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.
rng(0,"twister")
Create a default rlPPOAgent object using the environment specification objects.
ppoAgent = rlPPOAgent(obsInfo,actInfo);
Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.
ppoAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
ppoAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
ppoAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
ppoAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Discrete");
    % Train the agent. Record the training time.
    tic
    ppoTngRes = train(ppoAgent,env,trainOpts,Evaluator=evl);
    ppoTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    ppoTngEps = ppoTngRes.EpisodeIndex(end);
    ppoTngSteps = sum(ppoTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("dcpBchPPOAgent.mat", ...
    %     "ppoAgent","ppoTngEps","ppoTngSteps","ppoTngTime")
else
    % Load the pretrained agent and results for the example.
    load("dcpBchPPOAgent.mat", ...
        "ppoAgent","ppoTngEps","ppoTngSteps","ppoTngTime")
end
For the PPO Agent, the training converges to a solution after 200 episodes. You can check the trained agent within the cart-pole environment.
Ensure reproducibility of the simulation by fixing the seed used for random number generation.
rng(0,"twister")
Visualize the environment.
plot(env)
Configure the agent to use a greedy policy (no exploration) in simulation.
ppoAgent.UseExplorationPolicy = false;
Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.
experience = sim(env,ppoAgent,simOptions);
ppoTotalRwd = sum(experience.Reward)
ppoTotalRwd = 500
The trained PPO agent stabilizes the pole on the cart.
Create, Train, and Simulate a SAC Agent
The constructor functions initialize the agent networks randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.
rng(0,"twister")
Create a default rlSACAgent object using the environment specification objects.
sacAgent = rlSACAgent(obsInfo,actInfo);
Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.
sacAgent.AgentOptions.CriticOptimizerOptions(1).LearnRate = 1e-3;
sacAgent.AgentOptions.CriticOptimizerOptions(2).LearnRate = 1e-3;
sacAgent.AgentOptions.CriticOptimizerOptions(1).GradientThreshold = 1;
sacAgent.AgentOptions.CriticOptimizerOptions(2).GradientThreshold = 1;
sacAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
sacAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
Set the initial entropy weight and target entropy to increase exploration.
sacAgent.AgentOptions.EntropyWeightOptions.EntropyWeight = 5e-3;
sacAgent.AgentOptions.EntropyWeightOptions.TargetEntropy = 5e-1;
Use a larger experience buffer to store more experiences, thereby decreasing the likelihood of catastrophic forgetting.
sacAgent.AgentOptions.ExperienceBufferLength = 1e6;
Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Discrete");
    % Train the agent. Record the training time.
    tic
    sacTngRes = train(sacAgent,env,trainOpts,Evaluator=evl);
    sacTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    sacTngEps = sacTngRes.EpisodeIndex(end);
    sacTngSteps = sum(sacTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("dcpBchSACAgent.mat", ...
    %     "sacAgent","sacTngEps","sacTngSteps","sacTngTime")
else
    % Load the pretrained agent and results for the example.
    load("dcpBchSACAgent.mat", ...
        "sacAgent","sacTngEps","sacTngSteps","sacTngTime")
end
For the SAC agent, the training converges to a solution after 2000 episodes. You can check the trained agent within the cart-pole environment.
Ensure reproducibility of the simulation by fixing the seed used for random number generation.
rng(0,"twister")
Visualize the environment.
plot(env)
Configure the agent to use a greedy policy (no exploration) in simulation.
sacAgent.UseExplorationPolicy = false;
Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.
experience = sim(env,sacAgent,simOptions);
sacTotalRwd = sum(experience.Reward)
sacTotalRwd = 500
The trained SAC agent stabilizes the pole on the cart.
Plot Training and Simulation Metrics
For each agent, collect the total reward from the final simulation episode, the number of training episodes, the total number of agent steps, and the total training time as shown in the Reinforcement Learning Training Monitor.
simReward = [dqnTotalRwd; pgTotalRwd; acTotalRwd; ppoTotalRwd; sacTotalRwd];
tngEpisodes = [dqnTngEps; pgTngEps; acTngEps; ppoTngEps; sacTngEps];
tngSteps = [dqnTngSteps; pgTngSteps; acTngSteps; ppoTngSteps; sacTngSteps];
tngTime = [dqnTngTime; pgTngTime; acTngTime; ppoTngTime; sacTngTime];
Plot the simulation reward, the number of training episodes, the number of training steps (that is, the number of interactions between the agent and the environment), and the training time. Scale the data by the factor [10 50 5e6 50] for better visualization.
bar([simReward,tngEpisodes,tngSteps,tngTime]./[10 50 5e6 50])
xticklabels(["DQN" "PG" "AC" "PPO" "SAC"])
legend("Total Reward","Training Episodes", ...
    "Training Steps","Training Time", ...
    "Location","north")
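As a complement to the scaled bar chart, you can also collect the unscaled metrics in a table for direct numeric comparison. This sketch reuses the vectors defined above; the variable and row names are chosen only for readability.

```matlab
% Summarize the unscaled metrics in a table, one row per agent.
metrics = table(simReward,tngEpisodes,tngSteps,tngTime, ...
    VariableNames=["TotalReward","Episodes","AgentSteps","TrainingTimeSec"], ...
    RowNames=["DQN" "PG" "AC" "PPO" "SAC"]);
disp(metrics)
```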
The plot shows that, for this environment, with the selected random number generator seed and initial conditions, all agents except AC are able to stabilize the pole on the cart in simulation. AC, PG, and PPO use considerably less training time, while SAC, due to its more complex algorithm that performs more gradient calculations, uses considerably more. With a different random seed, the initial agent networks would be different, and the convergence results might differ as well. For more information on the relative strengths and weaknesses of each agent, see Reinforcement Learning Agents.
Save all the variables created in this example, including the training results, for later use.
% Uncomment to save all the workspace variables
% save dcpAllVariables.mat
Restore the random number stream using the information stored in previousRngState.
rng(previousRngState);