Train Reinforcement Learning Agent for Simple Contextual Bandit Problem

This example shows how to solve a contextual bandit problem [1] using reinforcement learning by training DQN and Q agents. For more information on these agents, see Deep Q-Network (DQN) Agents and Q-Learning Agents.

In contextual bandit problems, an agent selects an action given the initial observation (context) and receives a reward, after which the episode terminates. Hence, the agent's action does not affect the next observation.

Contextual bandits can be used for various applications such as hyperparameter tuning, recommender systems, medical treatment, and 5G communication.

The following figure shows how multi-armed bandits and contextual bandits are special cases of reinforcement learning.

In bandit problems, the environment has no dynamics, so the reward is only influenced by the current action and (for contextual bandits) the current observation (in these problems the observation is also referred to as context).

Neither rewards nor observations are influenced by any environment state (or by previous actions or observations), so the environment does not evolve along the time dimension, and there is no sequential decision making. The problem then becomes one of finding the action that maximizes the current reward (given a context, if present). Single-armed bandit problems are just special cases of multi-armed bandit problems in which the action is a scalar instead of a vector.

Environment

The contextual bandit environment in this example is defined as follows:

  • Observation (discrete): {1, 2}

The context (initial observation) is sampled randomly.

$$\Pr(s=1) = 0.5, \qquad \Pr(s=2) = 0.5$$

  • Action (discrete): {1, 2, 3}

  • Reward:

Rewards in this environment are stochastic. The reward distribution for each observation-action pair is defined as follows.

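1. s=1, a=1: $\Pr(r=5 \mid s=1, a=1) = 0.3$, $\Pr(r=2 \mid s=1, a=1) = 0.7$
2. s=1, a=2: $\Pr(r=10 \mid s=1, a=2) = 0.1$, $\Pr(r=1 \mid s=1, a=2) = 0.9$
3. s=1, a=3: $\Pr(r=3.5 \mid s=1, a=3) = 1$

4. s=2, a=1: $\Pr(r=10 \mid s=2, a=1) = 0.2$, $\Pr(r=2 \mid s=2, a=1) = 0.8$
5. s=2, a=2: $\Pr(r=3 \mid s=2, a=2) = 1$
6. s=2, a=3: $\Pr(r=5 \mid s=2, a=3) = 0.5$, $\Pr(r=0.5 \mid s=2, a=3) = 0.5$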

Note that the agent does not know these distributions.

  • Is-Done signal: Since this is a contextual bandit problem, each episode has only one step. Hence, the Is-Done signal is always 1.
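The environment definition above can be sketched in base MATLAB as a single episode: sample a context, pick an action, and draw a reward from the distribution of that pair. This is only an illustrative sketch under stated assumptions. The variables `vals` and `probs`, the seed, and the uniform random action are hypothetical, and the code is not the implementation of ToyContextualBanditEnvironment.

```matlab
% Hypothetical sketch of one bandit episode; not the
% ToyContextualBanditEnvironment implementation.
% Rows are contexts s = 1,2 and columns are actions a = 1,2,3.
vals  = {[5 2],     [10 1],    3.5; ...
         [10 2],    3,         [5 0.5]};      % possible reward values
probs = {[0.3 0.7], [0.1 0.9], 1; ...
         [0.2 0.8], 1,         [0.5 0.5]};    % matching probabilities

rng(0)               % arbitrary seed so the sketch is repeatable
s = randi(2);        % context: Pr(s=1) = Pr(s=2) = 0.5
a = randi(3);        % placeholder policy: uniform random action

% Inverse-CDF sampling over the discrete reward values of (s,a)
edges = cumsum(probs{s,a});
edges(end) = 1;      % guard against floating-point round-off
idx = find(rand <= edges, 1);
r = vals{s,a}(idx);

isDone = true;       % every episode ends after a single step
fprintf("s = %d, a = %d, r = %g, isDone = %d\n", s, a, r, isDone)
```

Because the context is drawn independently at the start of every episode, repeating this block gives independent samples, which is what makes the problem a bandit rather than a sequential decision problem.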

Create Environment Interface

Create the contextual bandit environment using ToyContextualBanditEnvironment, located in this example folder.

env = ToyContextualBanditEnvironment;
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Fix the random generator seed for reproducibility.

rng(1) 

Create DQN Agent

Create a DQN agent with a default network structure using rlAgentInitializationOptions.

agentOpts = rlDQNAgentOptions( ...
    UseDoubleDQN = false, ...
    TargetSmoothFactor = 1, ...        % no smoothing: copy the critic to the target
    TargetUpdateFrequency = 4, ...     % update the target critic periodically
    MiniBatchSize = 64, ...
    MaxMiniBatchPerEpoch = 2);
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.0005;

initOpts = rlAgentInitializationOptions(NumHiddenUnit = 16);

DQNagent = rlDQNAgent(obsInfo, actInfo, initOpts, agentOpts);
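To get a feel for what an EpsilonDecay of 0.0005 means over 3000 one-step episodes, the following sketch traces the exploration rate in base MATLAB. It assumes the multiplicative rule Epsilon = Epsilon*(1-EpsilonDecay), applied at every training step while Epsilon is above EpsilonMin. The initial value of 1 and the lower bound of 0.01 are assumptions for illustration, so check them against the EpsilonGreedyExploration settings of your release.

```matlab
% Sketch of an assumed multiplicative epsilon-decay schedule.
epsilon      = 1;        % assumed initial exploration rate
epsilonMin   = 0.01;     % assumed lower bound
epsilonDecay = 0.0005;   % decay rate set in the agent options
numSteps     = 3000;     % 3000 episodes with one step each

epsilonHistory = zeros(1,numSteps);
for k = 1:numSteps
    epsilonHistory(k) = epsilon;            % value used at step k
    if epsilon > epsilonMin
        epsilon = epsilon*(1 - epsilonDecay);
    end
end
fprintf("Epsilon after %d steps: %.3f\n", numSteps, epsilon)
```

Under these assumptions the exploration rate is still roughly 0.22 at the end of training, so the agent keeps sampling all three actions throughout the run.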

Train Agent

To train the agent, first specify the training options. For this example, use the following options:

  • Train for 3000 episodes.

  • Since this is a contextual bandit problem, and each episode has only one step, set MaxStepsPerEpisode to 1.

For more information, see rlTrainingOptions.

Train the agent using the train function. To save time while running this example, load a pre-trained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

MaxEpisodes = 3000;
trainOpts = rlTrainingOptions(...
    MaxEpisodes = MaxEpisodes, ...
    MaxStepsPerEpisode = 1, ...
    Verbose = false, ...
    Plots = "training-progress",...
    StopTrainingCriteria = "None",...
    StopTrainingValue = "None"); 

doTraining = false;
if doTraining
    % Train the agent
    trainingStats = train(DQNagent,env,trainOpts);
else
    % Load the pre-trained agent for the example
    load("ToyContextualBanditDQNAgent.mat","DQNagent")
end

Validate DQN Agent

Assume that you know the distribution of the rewards, and you can compute the optimal actions. Validate the agent's performance by comparing these optimal actions with the actions selected by the agent. First, compute the true expected rewards with the true distributions.

1. The expected reward of each action at s=1 is as follows.

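$$\begin{aligned}
\text{If } a=1:&\quad E[R] = 0.3 \cdot 5 + 0.7 \cdot 2 = 2.9 \\
\text{If } a=2:&\quad E[R] = 0.1 \cdot 10 + 0.9 \cdot 1 = 1.9 \\
\text{If } a=3:&\quad E[R] = 3.5
\end{aligned}$$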

Hence, the optimal action is 3 when s=1.

2. The expected reward of each action at s=2 is as follows.

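$$\begin{aligned}
\text{If } a=1:&\quad E[R] = 0.2 \cdot 10 + 0.8 \cdot 2 = 3.6 \\
\text{If } a=2:&\quad E[R] = 3.0 \\
\text{If } a=3:&\quad E[R] = 0.5 \cdot 5 + 0.5 \cdot 0.5 = 2.75
\end{aligned}$$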

Hence, the optimal action is 1 when s=2.

With enough sampling, the Q-values should approach the true expected rewards. Visualize the true expected rewards.

ExpectedRewards = zeros(2,3);
ExpectedRewards(1,1) = 0.3*5 + 0.7*2;
ExpectedRewards(1,2) = 0.1*10 + 0.9*1;
ExpectedRewards(1,3) = 3.5;
ExpectedRewards(2,1) = 0.2*10 + 0.8*2;
ExpectedRewards(2,2) = 3.0;
ExpectedRewards(2,3) = 0.5*5 + 0.5*0.5;

localPlotQvalues(ExpectedRewards, "Expected Rewards")

[Figure: heat map titled "Expected Rewards", with states s=1 and s=2 as rows, actions a=1 to a=3 as columns, and each cell labeled with its value.]
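The optimal actions stated above can also be read directly off the expected-reward matrix by taking the row-wise maximum. The following lines are a small sketch of that check. The matrix is retyped here so that the snippet runs on its own, and the names `bestReward` and `optimalAction` are introduced only for illustration.

```matlab
% Expected rewards: rows are states s = 1,2; columns are actions a = 1,2,3
ExpectedRewards = [0.3*5 + 0.7*2,  0.1*10 + 0.9*1,  3.5; ...
                   0.2*10 + 0.8*2, 3.0,             0.5*5 + 0.5*0.5];

% Row-wise maximum gives the best expected reward and the action achieving it
[bestReward, optimalAction] = max(ExpectedRewards, [], 2);

for s = 1:2
    fprintf("s = %d: optimal action = %d, expected reward = %g\n", ...
        s, optimalAction(s), bestReward(s))
end
```

The result is action 3 for s=1 and action 1 for s=2, which are the targets the trained agents are validated against.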

Now, validate whether the DQN agent learns the optimal behavior.

If the state is 1, the optimal action is 3.

observation = 1;
getAction(DQNagent,observation)
ans = 1×1 cell array
    {[3]}

The agent selects the optimal action.

If the state is 2, the optimal action is 1.

observation = 2;
getAction(DQNagent,observation)
ans = 1×1 cell array
    {[1]}

The agent selects the optimal action. Thus, the DQN agent has learned the optimal behavior.

Next, compare the learned Q-value function with the true expected rewards.

% Get critic
figure(1)
DQNcritic = getCritic(DQNagent);
QValues = zeros(2,3);
for s = 1:2
    QValues(s,:) = getValue(DQNcritic, {s});
end

% Visualize Q values
localPlotQvalues(QValues, "Q values")

[Figure: heat map titled "Q values" for the DQN agent, with states s=1 and s=2 as rows, actions a=1 to a=3 as columns, and each cell labeled with its value.]

The learned Q-values are close to the true expected rewards computed above.

Create Q-learning Agent

Next, train a Q-learning agent. To create a Q-learning agent, first create a table using the observation and action specifications from the environment.

rng(1); % For reproducibility
qTable = rlTable(obsInfo, actInfo);
critic = rlQValueFunction(qTable, obsInfo, actInfo);

opt = rlQAgentOptions;
opt.EpsilonGreedyExploration.Epsilon = 1;
opt.EpsilonGreedyExploration.EpsilonDecay = 0.0005;

Qagent = rlQAgent(critic,opt);

Train Q-Learning Agent

To save time while running this example, load a pre-trained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(Qagent,env,trainOpts);
else
    % Load the pre-trained agent for the example.
    load("ToyContextualBanditQAgent.mat","Qagent")
end

Validate Q-Learning Agent

When the state is 1, the optimal action is 3.

observation = 1;
getAction(Qagent,observation)
ans = 1×1 cell array
    {[3]}

The agent selects the optimal action.

When the state is 2, the optimal action is 1.

observation = 2;
getAction(Qagent,observation)
ans = 1×1 cell array
    {[1]}

The agent selects the optimal action. Hence, the Q-learning agent has learned the optimal behavior.

Next, compare the learned Q-value function with the true expected rewards.

% Get critic
figure(2)
Qcritic = getCritic(Qagent);
QValues = zeros(2,3);
for s = 1:2
    for a = 1:3
        QValues(s,a) = getValue(Qcritic, {s}, {a});
    end
end

% Visualize Q values
localPlotQvalues(QValues, "Q values")

[Figure: heat map titled "Q values" for the Q-learning agent, with states s=1 and s=2 as rows, actions a=1 to a=3 as columns, and each cell labeled with its value.]

Again, the learned Q-values are close to the true expected rewards. For the deterministic rewards, Q(s=1, a=3) and Q(s=2, a=2), the Q-values match the true expected rewards. The corresponding Q-values learned by the DQN agent, while close, are not identical to the true values, because the DQN agent uses a neural network instead of a table as its internal function approximator.
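One way to see why a table reproduces the deterministic rewards so closely is to write out a tabular update by hand. The sketch below is a simplified stand-in, not the rlQAgent implementation: the learning rate `alpha`, the seed, the uniform random exploration, and the `vals` and `probs` reward tables are all assumptions made for illustration. Because every episode ends after one step, the update target is just the sampled reward, with no bootstrapped next-state term.

```matlab
% Hypothetical tabular Q-learning sketch for the one-step bandit.
vals  = {[5 2],     [10 1],    3.5; ...
         [10 2],    3,         [5 0.5]};      % possible reward values
probs = {[0.3 0.7], [0.1 0.9], 1; ...
         [0.2 0.8], 1,         [0.5 0.5]};    % matching probabilities

rng(0)            % arbitrary seed
alpha = 0.1;      % assumed learning rate
Q = zeros(2,3);   % one table entry per (state, action) pair

for episode = 1:3000
    s = randi(2);                     % random context
    a = randi(3);                     % uniform exploration (simplification)
    edges = cumsum(probs{s,a});
    edges(end) = 1;                   % guard against round-off
    r = vals{s,a}(find(rand <= edges, 1));
    % Terminal after one step, so the target is the reward itself
    Q(s,a) = Q(s,a) + alpha*(r - Q(s,a));
end
disp(Q)
```

For the deterministic pairs, each update moves the entry a fixed fraction of the way toward the same constant reward, so Q(1,3) and Q(2,2) settle at 3.5 and 3. The entries for stochastic pairs keep fluctuating around their expected values, with a spread that depends on `alpha`.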

Local Function

function localPlotQvalues(QValues, titleText)
    % Visualize Q values 
    figure;
    imagesc(QValues,[1,4])
    colormap("autumn")
    title(titleText)
    colorbar
    set(gca,"Xtick",1:3,"XTickLabel",{"a=1", "a=2", "a=3"})
    set(gca,"Ytick",1:2,"YTickLabel",{"s=1", "s=2"})

    % Plot values on the image
    x = repmat(1:size(QValues,2), size(QValues,1), 1);
    y = repmat(1:size(QValues,1), size(QValues,2), 1)';
    QValuesStr = num2cell(QValues);
    QValuesStr = cellfun(@num2str, QValuesStr, UniformOutput=false);
    text(x(:), y(:), QValuesStr, HorizontalAlignment = "Center")
end

Reference

[1] Sutton, Richard S., and Andrew G. Barto. Reinforcement Learning: An Introduction. Second edition. Adaptive Computation and Machine Learning Series. Cambridge, Massachusetts: The MIT Press, 2018.
