Can the reset function of the environment for deep reinforcement learning incorporate feedback from the agent's state and reward?

Currently, I am designing a control system using deep reinforcement learning (DDPG) with Reinforcement Learning Toolbox in MATLAB/Simulink. Because some plant parameters (for example, the mass) are uncertain, I apply domain randomization to make the agent robust to these uncertainties. Specifically, I randomly vary the mass of the plant at the beginning of each training episode, which I implement in the environment's reset function.
At present, the range of parameter variation is fixed, and the mass is sampled from a uniform distribution within this range. In the future, I would like to adjust this range adaptively between episodes. More concretely, I want to design an adaptation algorithm in which the range of variation for a given episode is determined by the episode reward (or its moving average) that the agent achieved in the previous episode(s).
For example, if the episode reward is below a certain threshold, the mass variation range could be ±10%; if it is above the threshold, the range could be increased to ±20% or more.
However, I am not sure how to feed back the agent's performance (e.g., episode reward) or state (e.g., control output) into the reset function. How can I implement this kind of adaptive domain randomization in MATLAB/Simulink? I am currently using the train function for training the reinforcement learning agent. Under this standard training framework, is it possible to implement the aforementioned adaptive domain randomization?
obsInfo = rlNumericSpec([6 1]);
obsInfo.Name = "observations";
actInfo = rlNumericSpec([1 1]);
actInfo.Name = "control input";
mdl ='SIM_RL'; % Plant + RL agent block
env = rlSimulinkEnv( ...
"SIM_RL", ...
"SIM_RL/Agent/RL Agent", ...
obsInfo, actInfo);
% Domain randomization: reset function.
% Nominal_value is passed in explicitly because a local function cannot
% see variables in the script workspace.
env.ResetFcn = @(in) localResetFcn(in, Nominal_value);
% The construction of the critic network is omitted here.
% ....
criticNet = initialize(criticNet);
critic = rlQValueFunction(criticNet,obsInfo,actInfo);
% The construction of the actor network is omitted here.
% ....
actorNet = initialize(actorNet);
actor = rlContinuousDeterministicActor(actorNet,obsInfo,actInfo);
% Agent
criticOpts = rlOptimizerOptions(LearnRate=1e-04,GradientThreshold=1);
actorOpts = rlOptimizerOptions(LearnRate=1e-04,GradientThreshold=1);
agentOpts = rlDDPGAgentOptions( ...
    SampleTime=0.01, ...
    CriticOptimizerOptions=criticOpts, ...
    ActorOptimizerOptions=actorOpts, ...
    ExperienceBufferLength=1e5, ...
    DiscountFactor=0.99, ...
    MiniBatchSize=128, ...
    TargetSmoothFactor=1e-3);
agent = rlDDPGAgent(actor,critic,agentOpts);
maxepisodes = 5000;
maxsteps = ceil(Simulation_End_Time/0.01);
trainOpts = rlTrainingOptions( ...
    MaxEpisodes=maxepisodes, ...
    MaxStepsPerEpisode=maxsteps, ...
    ScoreAveragingWindowLength=5, ...
    Verbose=true, ...
    Plots="training-progress", ...
    StopTrainingCriteria="EpisodeCount", ...
    SaveAgentCriteria="EpisodeReward", ...
    SaveAgentValue=-1.0);
doTraining = true;
if doTraining
    evaluator = rlEvaluator( ...
        NumEpisodes=1, ...
        EvaluationFrequency=5);
    % Train the agent.
    trainingStats = train(agent,env,trainOpts,Evaluator=evaluator);
else
    % Load the pretrained agent.
    load("agent.mat","agent")
end

% Local functions must be placed at the end of a script file.
function in = localResetFcn(in, Nominal_value)
    % Fixed range of the plant parameter
    M_min = Nominal_value*(1 - 0.5);   % -50% of nominal mass
    M_max = Nominal_value*(1 + 0.5);   % +50% of nominal mass
    % <------ We would like to change this variation range adaptively.
    % Randomize the mass (uniform distribution in [M_min, M_max])
    randomValue_M = M_min + (M_max - M_min)*rand;
    in = setBlockParameter(in, ...
        "SIM_RL/Plant/Mass", ...
        Value=num2str(randomValue_M));
end

Answers (1)

Nithin, answered on 11 Jun 2025 (edited 11 Jun 2025)
Hi @平成,
The "Reinforcement Learning Toolbox" does not natively provide a mechanism for passing data from one episode to the next through the environment's "ResetFcn", because "ResetFcn" is stateless by default. One workaround is to keep the most recent episode reward in a persistent variable, a global variable, or another stateful object outside the environment, and to read it inside "ResetFcn" to set the mass range.
A more structured alternative is a custom handle class that tracks performance across episodes. The general approach is:
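A minimal sketch of the persistent-variable variant, written as a hypothetical function file adaptiveMassRange.m (not a toolbox function). The reward threshold of 100 and the ±10 % / ±20 % ranges are the example values from the question and are assumptions; only the most recent episode reward is kept.

```matlab
function [Mmin, Mmax] = adaptiveMassRange(nominalMass, lastReward)
% ADAPTIVEMASSRANGE  Hypothetical helper, not part of Reinforcement Learning Toolbox.
%   adaptiveMassRange(nominalMass, lastReward) stores the reward of the
%   episode that just finished.
%   [Mmin, Mmax] = adaptiveMassRange(nominalMass) returns the mass range for
%   the next episode, based on the stored reward.
persistent storedReward          % survives between calls, starts as []
if nargin > 1
    storedReward = lastReward;   % update call: remember the latest reward
end
threshold = 100;                 % assumed reward threshold
if isempty(storedReward) || storedReward < threshold
    rangePercent = 0.1;          % +/-10 % while performance is low (or unknown)
else
    rangePercent = 0.2;          % +/-20 % once the threshold is reached
end
Mmin = nominalMass*(1 - rangePercent);
Mmax = nominalMass*(1 + rangePercent);
end
```

In a custom loop you would call adaptiveMassRange(M0, episodeReward) after each episode and [Mmin, Mmax] = adaptiveMassRange(M0) inside the reset function; `clear adaptiveMassRange` resets the stored reward. Persistent and global variables live in a single MATLAB session, so this only works for serial training: with parallel training, each worker has its own copy.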
  • Create a Custom Tracker Class
classdef DomainRandomizationManager < handle
    % Tracks episode rewards and derives the mass randomization range from them.
    properties
        rewardHistory = [];   % reward of every finished episode
        NominalValue = 10;    % example nominal mass
    end
    methods
        function updateReward(obj, reward)
            % Append the reward of the episode that just finished.
            obj.rewardHistory(end+1) = reward;
        end
        function [Mmin, Mmax] = getRandomizationRange(obj)
            if isempty(obj.rewardHistory)
                rangePercent = 0.1;   % no history yet: start with +/-10%
            else
                % Moving average over the last (up to) 5 episodes
                avgReward = mean(obj.rewardHistory(max(end-4,1):end));
                if avgReward < 100    % example reward threshold
                    rangePercent = 0.1;
                else
                    rangePercent = 0.2;
                end
            end
            Mmin = obj.NominalValue * (1 - rangePercent);
            Mmax = obj.NominalValue * (1 + rangePercent);
        end
    end
end
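A quick check of the class on its own, outside any training, assuming DomainRandomizationManager.m is saved on the MATLAB path. The five reward values are made up purely for illustration.

```matlab
% Assumes DomainRandomizationManager.m is on the MATLAB path.
drm = DomainRandomizationManager();
% No reward history yet, so the range is +/-10 % of 10 (9 to 11)
[lo, hi] = drm.getRandomizationRange();
fprintf("Initial range: %g to %g\n", lo, hi);
% Record five made-up episode rewards; their mean is 126
for r = [120 130 90 150 140]
    drm.updateReward(r);
end
% Moving average 126 >= 100, so the range widens to +/-20 % (8 to 12)
[lo, hi] = drm.getRandomizationRange();
fprintf("Range after 5 episodes: %g to %g\n", lo, hi);
```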
  • Use it in the Reset Function
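For this step, a possible reset function, assuming (as in the question) that the mass is set through the Value parameter of the block SIM_RL/Plant/Mass. Instead of a global variable, the tracker is passed in explicitly, for example with env.ResetFcn = @(in) localResetFcn(in, DRM). Because DomainRandomizationManager derives from handle, the object captured by the anonymous function is the same one whose updateReward method is called between episodes.

```matlab
function in = localResetFcn(in, DRM)
    % Sketch: sample the plant mass from the range currently chosen by the
    % DomainRandomizationManager handle object DRM.
    [Mmin, Mmax] = DRM.getRandomizationRange();
    % Uniform sample in [Mmin, Mmax]
    M = Mmin + (Mmax - Mmin)*rand;
    % Write the sampled mass to the plant's mass block for this episode
    in = setBlockParameter(in, ...
        "SIM_RL/Plant/Mass", ...
        Value=num2str(M));
end
```

In a script, this function has to be placed at the end of the file.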
  • Hook into the training loop. MATLAB's built-in "train" function offers no direct way to pass results from one episode to the next, so for full flexibility, drive training from your own loop and update the tracker after every episode:
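% Create the tracker. DRM is a handle object, so the copy captured by the
% anonymous reset function is the same object that is updated in the loop.
DRM = DomainRandomizationManager();
% localResetFcn(in, DRM) is expected to call DRM.getRandomizationRange()
% and sample the mass from the returned range.
env.ResetFcn = @(in) localResetFcn(in, DRM);
% Run train for a single episode per call so the tracker can be updated
% between episodes (reuses trainOpts and maxepisodes from the question).
episodeOpts = trainOpts;
episodeOpts.MaxEpisodes = 1;
episodeOpts.Plots = "none";     % avoid opening a new plot window per call
episodeOpts.Verbose = false;
for ep = 1:maxepisodes
    % agent is a handle object, so train keeps updating the same agent.
    % Check whether your release clears the replay buffer at the start of
    % each train call (agent option ResetExperienceBufferBeforeTraining in
    % releases that have it); for DDPG the buffer should persist.
    episodeStats = train(agent, env, episodeOpts);
    % Feed the episode reward back so the next reset uses the updated range.
    DRM.updateReward(episodeStats.EpisodeReward(end));
end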
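The question also asks for ranges of "±20 % or more". One way to get there, sketched as a hypothetical helper nextRangePercent with illustrative threshold, step and cap values, is a ratchet: widen the range each time the moving-average reward clears the threshold, otherwise keep it. This is a swapped-in curriculum rule, not something the class above or the toolbox provides.

```matlab
function rangePercent = nextRangePercent(rangePercent, avgReward, threshold, step, maxPercent)
% NEXTRANGEPERCENT  Hypothetical curriculum rule, not a toolbox function.
%   Widen the relative mass range by STEP whenever the moving-average
%   reward AVGREWARD reaches THRESHOLD, never exceeding MAXPERCENT.
%   Below the threshold, the current range is kept unchanged.
if avgReward >= threshold
    rangePercent = min(rangePercent + step, maxPercent);
end
end
```

Inside getRandomizationRange you could store rangePercent as a property and update it with this rule instead of switching between two fixed values. Unlike the two-level rule, the ratchet never shrinks the range after a bad episode; whether that is desirable depends on how noisy the episode reward is.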
Refer to the following MATLAB documentation for more information:
