rlMBPOAgent

Model-based policy optimization reinforcement learning agent

    Description

    A model-based policy optimization (MBPO) agent uses a model-based, online, off-policy reinforcement learning method. An MBPO agent contains an internal model of the environment, which it uses to generate additional experiences without interacting with the environment.

    During training, the MBPO agent generates real experiences by interacting with the environment. These experiences are used to train the internal environment model, which is used to generate additional experiences. The training algorithm then uses both the real and generated experiences to update the agent policy.
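
    The alternation between real interaction, model learning, and model rollouts can be sketched as follows. This is illustrative pseudocode, not runnable toolbox API; the helper names (stepRealEnvironment, updateModel, generateRollouts, updateBaseAgent) are hypothetical.

```matlab
% Illustrative pseudocode of the MBPO training loop.
for episode = 1:maxEpisodes
    % 1. Interact with the real environment and store real experiences.
    realExp = stepRealEnvironment(agent,env);
    store(realBuffer,realExp);

    % 2. Fit the internal environment model to the real experiences.
    updateModel(agent.EnvModel,realBuffer);

    % 3. Generate short model rollouts starting from real states and
    %    store the generated experiences.
    modelExp = generateRollouts(agent.EnvModel,realBuffer,rolloutHorizon);
    store(modelBuffer,modelExp);

    % 4. Update the base off-policy agent using a mix of real and
    %    model-generated experiences.
    updateBaseAgent(agent.BaseAgent,realBuffer,modelBuffer);
end
```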

    Creation

    Description

    example

    agent = rlMBPOAgent(baseAgent,envModel) creates a model-based policy optimization agent with default options and sets the BaseAgent and EnvModel properties.

    agent = rlMBPOAgent(___,agentOptions) creates a model-based policy optimization agent using specified options and sets the AgentOptions property.

    Properties

    Base reinforcement learning agent, specified as an off-policy agent object.

    For environments with a discrete action space, specify a DQN agent using an rlDQNAgent object.

    For environments with a continuous action space, use a DDPG, TD3, or SAC agent, specified as an rlDDPGAgent, rlTD3Agent, or rlSACAgent object, respectively.

    Environment model, specified as an rlNeuralNetworkEnvironment object. This environment contains transition functions, a reward function, and an is-done function.

    Agent options, specified as an rlMBPOAgentOptions object.

    Current roll-out horizon value, specified as a positive integer. For more information on setting the initial horizon value and the horizon update method, see rlMBPOAgentOptions.
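
    For example, the following sketch configures the rollout behavior through the agent options. The ModelRolloutOptions property names used here are assumptions; confirm them on the rlMBPOAgentOptions page.

```matlab
% Sketch: configure model rollouts via the agent options
% (property names assumed; see rlMBPOAgentOptions).
opts = rlMBPOAgentOptions;
opts.ModelRolloutOptions.Horizon = 1;   % initial rollout horizon
opts.ModelRolloutOptions.HorizonUpdateSchedule = "piecewise"; % grow horizon during training
```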

    Model experience buffer, specified as an rlReplayMemory object. During training, the agent stores each of its generated experiences (S,A,R,S',D) in a buffer. Here:

    • S is the current observation of the environment.

    • A is the action taken by the agent.

    • R is the reward for taking action A.

    • S' is the next observation after taking action A.

    • D is the is-done signal after taking action A.
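
    For example, you can size this buffer when configuring the agent. This sketch assumes the ModelExperienceBufferLength option described on the rlMBPOAgentOptions page.

```matlab
% Sketch: store up to 100,000 model-generated experiences
% (option name assumed; see rlMBPOAgentOptions).
opts = rlMBPOAgentOptions;
opts.ModelExperienceBufferLength = 1e5;
```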

    Option to use exploration policy when selecting actions, specified as one of the following logical values.

    • true — Use the base agent exploration policy when selecting actions.

    • false — Use the base agent greedy policy when selecting actions.

    The initial value of UseExplorationPolicy matches the value specified in BaseAgent. If you change the value of UseExplorationPolicy in either the base agent or the MBPO agent, the same value is used for the other agent.
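
    For example, assuming agent is an existing rlMBPOAgent object:

```matlab
% Select actions with the base agent exploration policy, for example
% to keep exploring while collecting data.
agent.UseExplorationPolicy = true;

% Select actions with the base agent greedy policy for deterministic
% evaluation.
agent.UseExplorationPolicy = false;
```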

    This property is read-only.

    Observation specifications, specified as a reinforcement learning specification object or an array of specification objects defining properties such as dimensions, data type, and names of the observation signals.

    The value of ObservationInfo matches the corresponding value specified in BaseAgent.

    This property is read-only.

    Action specification, specified as a reinforcement learning specification object or an array of specification objects defining properties such as dimensions, data type, and names of the action signals.

    The value of ActionInfo matches the corresponding value specified in BaseAgent.

    Sample time of agent, specified as a positive scalar or as -1. Setting this parameter to -1 allows for event-based simulations.

    The initial value of SampleTime matches the value specified in BaseAgent. If you change the value of SampleTime in either the base agent or the MBPO agent, the same value is used for the other agent.

    Within a Simulink® environment, the RL Agent block in which the agent is specified executes every SampleTime seconds of simulation time. If SampleTime is -1, the block inherits the sample time from its parent subsystem.

    Within a MATLAB® environment, the agent is executed every time the environment advances. In this case, SampleTime is the time interval between consecutive elements in the output experience returned by sim or train. If SampleTime is -1, the time interval between consecutive elements in the returned output experience reflects the timing of the event that triggers the agent execution.
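
    For example, assuming agent is an existing rlMBPOAgent object, the following sketch sets the sample time; because the value is shared, the base agent uses the same sample time.

```matlab
% Execute the agent every 0.1 seconds of simulation time.
agent.SampleTime = 0.1;

% Alternatively, use event-based simulation (the timing is inherited
% from the triggering event or parent subsystem).
agent.SampleTime = -1;
```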

    Object Functions

    train — Train reinforcement learning agents within a specified environment
    sim — Simulate trained reinforcement learning agents within a specified environment
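
    For example, assuming agent and env are an existing MBPO agent and environment, a typical training and simulation workflow looks like this (the episode limits and stopping values are illustrative):

```matlab
% Train the agent; rlTrainingOptions controls episode limits and
% stopping criteria.
trainOpts = rlTrainingOptions( ...
    MaxEpisodes=200, ...
    MaxStepsPerEpisode=500, ...
    StopTrainingCriteria="AverageReward", ...
    StopTrainingValue=470);
trainStats = train(agent,env,trainOpts);

% Simulate the trained agent for one episode.
simOpts = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOpts);
```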

    Examples

    Create an environment interface and extract observation and action specifications.

    env = rlPredefinedEnv("CartPole-Continuous");
    obsInfo = getObservationInfo(env);
    actInfo = getActionInfo(env);

    Create a base off-policy agent. For this example, use a SAC agent.

    agentOpts = rlSACAgentOptions;
    agentOpts.MiniBatchSize = 256;
    initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);
    baseagent = rlSACAgent(obsInfo,actInfo,initOpts,agentOpts);

    Check your agent with a random input observation.

    getAction(baseagent,{rand(obsInfo.Dimension)})
    ans = 1x1 cell array
        {[-7.2875]}
    
    

    The neural network environment uses a function approximator object to approximate the environment transition function. The function approximator object uses one or more neural networks as approximation models. To account for modeling uncertainty, you can specify multiple transition models. For this example, create a single transition model.

    Create a neural network to use as the approximation model within the transition function object. Define each network path as an array of layer objects. Specify names for the input and output layers, so you can later explicitly associate them with the appropriate channels.

    % Observation and action paths
    obsPath = featureInputLayer(obsInfo.Dimension(1),Name="obsIn");
    actionPath = featureInputLayer(actInfo.Dimension(1),Name="actIn");
    
    % Common path: concatenate along dimension 1
    commonPath = [concatenationLayer(1,2,Name="concat")
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(obsInfo.Dimension(1),Name="nextObsOut")];
    
    % Add layers to layerGraph object
    transNet = layerGraph(obsPath);
    transNet = addLayers(transNet,actionPath);
    transNet = addLayers(transNet,commonPath);
    
    % Connect layers
    transNet = connectLayers(transNet,"obsIn","concat/in1");
    transNet = connectLayers(transNet,"actIn","concat/in2");
    
    % Convert to dlnetwork object
    transNet = dlnetwork(transNet);
    
    % Display number of weights
    summary(transNet)
       Initialized: true
    
       Number of learnables: 4.8k
    
       Inputs:
          1   'obsIn'   4 features
          2   'actIn'   1 features
    

    Create the transition function approximator object.

    transitionFcnAppx = rlContinuousDeterministicTransitionFunction( ...
        transNet,obsInfo,actInfo,...
        ObservationInputNames="obsIn",...
        ActionInputNames="actIn",...
        NextObservationOutputNames="nextObsOut");

    Create a neural network to use as a reward model for the reward function approximator object.

    % Observation and action paths
    actionPath = featureInputLayer(actInfo.Dimension(1),Name="actIn");
    nextObsPath = featureInputLayer(obsInfo.Dimension(1),Name="nextObsIn");
    
    % Common path: concatenate along dimension 1
    commonPath = [concatenationLayer(1,2,Name="concat")
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(1)];
    
    % Add layers to layerGraph object
    rewardNet = layerGraph(nextObsPath);
    rewardNet = addLayers(rewardNet,actionPath);
    rewardNet = addLayers(rewardNet,commonPath);
    
    % Connect layers
    rewardNet = connectLayers(rewardNet,"nextObsIn","concat/in1");
    rewardNet = connectLayers(rewardNet,"actIn","concat/in2");
    
    % Convert to dlnetwork object
    rewardNet = dlnetwork(rewardNet);
    
    % Display number of weights
    summary(rewardNet)
       Initialized: true
    
       Number of learnables: 8.8k
    
       Inputs:
          1   'nextObsIn'   4 features
          2   'actIn'       1 features
    

    Create the reward function approximator object.

    rewardFcnAppx = rlContinuousDeterministicRewardFunction( ...
        rewardNet,obsInfo,actInfo, ...
        ActionInputNames="actIn",...
        NextObservationInputNames="nextObsIn");

    Create a neural network to use as an is-done model for the is-done function approximator object.

    % Define main path
    net = [featureInputLayer(obsInfo.Dimension(1),Name="nextObsIn")
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(2)
        softmaxLayer(Name="isdoneOut")];
    
    % Convert to layerGraph object
    isDoneNet = layerGraph(net);
    
    % Convert to dlnetwork object
    isDoneNet = dlnetwork(isDoneNet);
    
    % Display number of weights
    summary(isDoneNet)
       Initialized: true
    
       Number of learnables: 4.6k
    
       Inputs:
          1   'nextObsIn'   4 features
    

    Create the is-done function approximator object.

    isdoneFcnAppx = rlIsDoneFunction(isDoneNet,obsInfo,actInfo, ...
        NextObservationInputNames="nextObsIn");

    Create the neural network environment using the observation and action specifications and the three function approximator objects.

    generativeEnv = rlNeuralNetworkEnvironment( ...
        obsInfo,actInfo,...
        transitionFcnAppx,rewardFcnAppx,isdoneFcnAppx);

    Specify options for creating an MBPO agent. Specify the optimizer options for the transition network and use default values for all other options.

    MBPOAgentOpts = rlMBPOAgentOptions;
    MBPOAgentOpts.TransitionOptimizerOptions = rlOptimizerOptions(...
        LearnRate=1e-4,...
        GradientThreshold=1.0);
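
    You can configure the optimizers for the reward and is-done models in the same way. This sketch assumes the RewardOptimizerOptions and IsDoneOptimizerOptions properties described on the rlMBPOAgentOptions page.

```matlab
% Optimizer options for the reward model (property name assumed;
% see rlMBPOAgentOptions).
MBPOAgentOpts.RewardOptimizerOptions = rlOptimizerOptions( ...
    LearnRate=1e-4, ...
    GradientThreshold=1.0);

% Optimizer options for the is-done model (property name assumed).
MBPOAgentOpts.IsDoneOptimizerOptions = rlOptimizerOptions( ...
    LearnRate=1e-4, ...
    GradientThreshold=1.0);
```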

    Create the MBPO agent.

    agent = rlMBPOAgent(baseagent,generativeEnv,MBPOAgentOpts);

    Check your agent with a random input observation.

    getAction(agent,{rand(obsInfo.Dimension)})
    ans = 1x1 cell array
        {[7.8658]}
    
    

    Version History

    Introduced in R2022a