Custom Agents

To implement your own custom reinforcement learning algorithms, you can create a custom agent by subclassing an abstract custom agent class. For more information about creating MATLAB® classes, see User-Defined Classes (MATLAB).

Create Template Class

To define your custom agent, first create a class that is a subclass of the rl.agent.CustomAgent class. As an example, this topic describes the custom LQR agent trained in Train Custom LQR Agent. As a starting point for your own agent, you can open and modify this custom agent class. To do so, add the example files to the path and open the file. In the MATLAB command window, type:

edit LQRCustomAgent.m

This class has the following class definition, which indicates the agent class name and the abstract agent class that it subclasses.

classdef LQRCustomAgent < rl.agent.CustomAgent

To define your agent you must specify the following:

  • Agent properties

  • Constructor function

  • Critic representation that estimates the discounted long-term reward (if required for learning)

  • Actor representation that selects an action based on the current observation (if required for learning)

  • Required agent methods

  • Optional agent methods

Agent Properties

In the properties section of the class file, specify any parameters necessary for creating and training the agent. These parameters can include:

  • Discount factor for discounting future rewards

  • Configuration parameters for exploration models, such as noise models or epsilon-greedy exploration

  • Experience buffers for using replay memory

  • Mini-batch sizes for sampling from the experience buffer

  • Number of steps to look ahead during training

For more information on potential agent properties, see the option sets for the built-in Reinforcement Learning Toolbox™ agents.

The rl.agent.CustomAgent class already includes properties for the agent sample time (SampleTime) and the action and observation specifications (ActionInfo and ObservationInfo, respectively).

The custom LQR agent defines the following agent properties.

properties
    % State penalty matrix
    Q

    % Control penalty matrix
    R

    % Feedback gain matrix
    K

    % Discount factor
    Gamma = 0.95

    % Critic representation
    Critic

    % Buffer of feedback gains computed during training
    KBuffer
    % Number of updates to K
    KUpdate = 1

    % Number of experiences to collect before each estimator update
    EstimateNum = 10
end

properties (Access = private)
    Counter = 1
    YBuffer
    HBuffer 
end

Constructor Function

To create your custom agent, you must define a constructor function that:

  • Defines the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.

  • Creates actor and critic representations as required by your training algorithm. For more information, see rlRepresentation.

  • Configures agent properties.

  • Calls the constructor of the base abstract class.

For example, the LQRCustomAgent constructor defines continuous action and observation spaces and creates a critic representation. The createCritic function is an optional helper function used for defining the critic representation.

function obj = LQRCustomAgent(Q,R,InitialK)
    % Check the number of input arguments
    narginchk(3,3);

    % Call the abstract class constructor
    obj = obj@rl.agent.CustomAgent();

    % Set the Q and R matrices
    obj.Q = Q;
    obj.R = R;

    % Define the observation and action spaces
    obj.ObservationInfo = rlNumericSpec([size(Q,1),1]);
    obj.ActionInfo = rlNumericSpec([size(R,1),1]);

    % Create the critic representation
    obj.Critic = createCritic(obj);

    % Initialize the gain matrix
    obj.K = InitialK;

    % Initialize the experience buffers
    obj.YBuffer = zeros(obj.EstimateNum,1);
    num = size(Q,1) + size(R,1);
    obj.HBuffer = zeros(obj.EstimateNum,0.5*num*(num+1));
    obj.KBuffer = cell(1,1000);
    obj.KBuffer{1} = obj.K;
end

Actor and Critic Representations

If your learning algorithm uses a critic representation to estimate the long term reward, an actor for selecting an action, or both, you must add these as agent properties. You must then create these representations when you create your agent; that is, in the constructor function. For more information on creating actors and critics, see Create Policy and Value Function Representations.

For example, the custom LQR agent uses a critic representation, stored in its Critic property, and no actor. The critic creation is implemented in the createCritic helper function, which is called from the LQRCustomAgent constructor.

function critic = createCritic(obj)
    nQ = size(obj.Q,1);
    nR = size(obj.R,1);
    n = nQ+nR;
    w0 = 0.1*ones(0.5*(n+1)*n,1);
    critic = rlRepresentation(@(x,u) computeQuadraticBasis(x,u,n),w0,...
        getObservationInfo(obj),getActionInfo(obj));
    critic.Options.GradientThreshold = 1;
    critic = critic.setLoss('mse');
end

In this case, the critic is an rlLinearBasisRepresentation object created using the rlRepresentation function. To create such a representation, you must specify a handle to a custom basis function, in this case the computeQuadraticBasis function. For more information on this critic representation, see Train Custom LQR Agent.
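The computeQuadraticBasis function itself is defined in the Train Custom LQR Agent example. A minimal sketch of such a basis function, which returns the upper-triangular products of the combined observation-action vector, might look like the following (an illustrative sketch, not the example's exact implementation):

```matlab
function B = computeQuadraticBasis(x,u,n)
    % Sketch of a quadratic basis: stack the observation and action
    % into one vector z and return the 0.5*n*(n+1) products z(r)*z(c)
    % from the upper triangle of z*z'.
    z = cat(1,x,u);
    B = zeros(0.5*n*(n+1),1);
    idx = 1;
    for r = 1:n
        for c = r:n
            B(idx) = z(r)*z(c);
            idx = idx + 1;
        end
    end
end
```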

Required Functions

To create a custom reinforcement learning agent you must define the following implementation functions. To call these functions in your own code, use the wrapper methods from the abstract base class. For example, to call getActionImpl, use getAction. The wrapper methods have the same input and output arguments as the implementation methods.

Function                        Description
getActionImpl                   Selects an action by evaluating the agent policy for a given observation
getActionWithExplorationImpl    Selects an action using the exploration model of the agent
learnImpl                       Learns from the current experience and returns an action with exploration
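Taken together, the required methods give a custom agent class the following overall shape. This is a skeleton only; MyCustomAgent is a placeholder name, and the method bodies and properties depend on your algorithm.

```matlab
classdef MyCustomAgent < rl.agent.CustomAgent
    % Skeleton of a custom agent class (MyCustomAgent is a placeholder)
    methods (Access = protected)
        function action = getActionImpl(obj,Observation)
            % Return the greedy action for the current policy
        end
        function action = getActionWithExplorationImpl(obj,Observation)
            % Return an action with exploration applied
        end
        function action = learnImpl(obj,exp)
            % Update learnable parameters from exp, then return an
            % action with exploration
        end
    end
end
```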

Within your implementation functions, to evaluate your actor and critic representations, you can use the evaluate function. To evaluate:

  • A critic with only observation input signals, obtain the value function V using:

    V = evaluate(this.Critic,Observation);
  • A critic with both observation and action input signals, obtain the Q function Q using:

    Q = evaluate(this.Critic,[Observation,Action]);
  • An actor, obtain the action A using:

    A = evaluate(this.Actor,Observation);

    If the actor has a continuous action space, A contains the values of the action signals. If the actor has a discrete action space, A contains the probability of taking each action.

In the preceding syntaxes, Observation and Action are cell arrays.

getActionImpl Function

The getActionImpl function is used to evaluate the policy of your agent and select an action. This function must have the following signature, where obj is the agent object, Observation is the current observation, and action is the selected action.

function action = getActionImpl(obj,Observation)

For the custom LQR agent, the agent selects an action by applying the control law u = -Kx.

function action = getActionImpl(obj,Observation)
    % Given the current state of the system, return an action.
    action = -obj.K*Observation{:};
end

getActionWithExplorationImpl Function

The getActionWithExplorationImpl function selects an action using the exploration model of your agent. Using this function you can implement algorithms such as epsilon-greedy exploration. This function must have the following signature, where obj is the agent object, Observation is the current observation, and action is the selected action.

function action = getActionWithExplorationImpl(obj,Observation)

For the custom LQR agent, the getActionWithExplorationImpl function adds random white noise to an action selected using the current agent policy.

function action = getActionWithExplorationImpl(obj,Observation)
    % Given the current observation, select an action
    action = getAction(obj,Observation);
    
    % Add random noise to the action
    num = size(obj.R,1);
    action = action + 0.1*randn(num,1);
end
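For an agent with a discrete action space, the same method could instead implement epsilon-greedy exploration. The following sketch assumes a hypothetical Epsilon property on the agent and an rlFiniteSetSpec action specification, whose Elements property lists the possible actions:

```matlab
function action = getActionWithExplorationImpl(obj,Observation)
    % Epsilon-greedy sketch for a discrete action space. Epsilon is a
    % hypothetical agent property (not part of the LQR example).
    if rand < obj.Epsilon
        % Explore: pick a random action from the finite set
        actions = obj.ActionInfo.Elements;
        action = actions(randi(numel(actions)));
    else
        % Exploit: use the current policy
        action = getAction(obj,Observation);
    end
end
```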

learnImpl Function

The learnImpl function defines how the agent learns from the current experience. This function implements the custom learning algorithm of your agent by updating the policy parameters and selecting an action with exploration. This function must have the following signature, where obj is the agent object, exp is the current agent experience, and action is the selected action.

function action = learnImpl(obj,exp)

The agent experience is the cell array exp = {state,action,reward,nextState,isDone}. Here:

  • state is the current observation.

  • action is the current action.

  • reward is the current reward.

  • nextState is the next observation.

  • isDone is a logical flag indicating that the training episode is complete.

For the custom LQR agent, the critic parameters are updated every N steps.

function action = learnImpl(obj,exp)
    % Parse the experience input
    x = exp{1}{1};
    u = exp{2}{1};
    dx = exp{4}{1};            
    y = (x'*obj.Q*x + u'*obj.R*u);
    num = size(obj.Q,1) + size(obj.R,1);

    % Wait N steps before updating critic parameters
    N = obj.EstimateNum;
    h1 = computeQuadraticBasis(x,u,num);
    h2 = computeQuadraticBasis(dx,-obj.K*dx,num);
    H = h1 - obj.Gamma* h2;
    if obj.Counter<=N
        obj.YBuffer(obj.Counter) = y;
        obj.HBuffer(obj.Counter,:) = H;
        obj.Counter = obj.Counter + 1;
    else
        % Update the critic parameters based on the batch of
        % experiences
        H_buf = obj.HBuffer;
        y_buf = obj.YBuffer;
        theta = (H_buf'*H_buf)\H_buf'*y_buf;
        setLearnableParameterValues(obj.Critic,{theta});
        
        % Derive a new gain matrix based on the new critic parameters
        obj.K = getNewK(obj);
        
        % Reset the experience buffers
        obj.Counter = 1;
        obj.YBuffer = zeros(N,1);
        obj.HBuffer = zeros(N,0.5*num*(num+1));    
        obj.KUpdate = obj.KUpdate + 1;
        obj.KBuffer{obj.KUpdate} = obj.K;
    end

    % Find and return an action with exploration
    action = getActionWithExploration(obj,exp{4});
end

Optional Functions

Optionally, you can define how your agent is reset at the start of training by specifying a resetImpl function with the following signature, where obj is the agent object. Using this function, you can set the agent into a known or random condition before training.

function resetImpl(obj)
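For example, a resetImpl function for the custom LQR agent might clear the experience buffers and counter so that each training run starts from a clean state. This is a sketch of what such a function could look like; the example itself does not define one:

```matlab
function resetImpl(obj)
    % Sketch: return the agent to its initial condition before training
    obj.Counter = 1;
    obj.YBuffer = zeros(obj.EstimateNum,1);
    num = size(obj.Q,1) + size(obj.R,1);
    obj.HBuffer = zeros(obj.EstimateNum,0.5*num*(num+1));
end
```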

Also, you can define any other helper functions in your custom agent class as required. For example, the custom LQR agent defines a createCritic function for creating the critic representation and a getNewK function that derives the feedback gain matrix from the trained critic parameters.

Create Custom Agent

Once you have defined your custom agent class, create an instance of it in the MATLAB workspace. For example, to create the custom LQR agent, define the Q, R, and initial gain values, and call the constructor function.

% A and B are the state-space matrices of the system being controlled
Q = [10,3,1;3,5,4;1,4,9]; 
R = 0.5*eye(3);
K0 = place(A,B,[0.4,0.8,0.5]);
agent = LQRCustomAgent(Q,R,K0);
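Once created, you can exercise the agent through the wrapper methods of the base class. For example, since the Q matrix above is 3-by-3, the observation is a 3-by-1 vector passed in a cell array:

```matlab
obs = {rand(3,1)};          % observations are passed as cell arrays
act = getAction(agent,obs); % greedy action, here -K*obs{:}
```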

After creating the agent object, you can train it against an environment. For an example that trains the custom LQR agent, see Train Custom LQR Agent.
