How to speed up vectorized operations for dynamic programming

I would like to speed up the following code, which solves a discrete dynamic programming problem using the method of successive approximations (value function iteration), as described, e.g., in Bertsekas.
The algorithm consists of two steps. In step 1, I precompute the payoff array R(a',a,z) (called Ret in the code), where a' is the action (next-period assets) and (a,z) are the states. In step 2, I compute the value function by successive approximations: I start from a guess V0, compute an updated V1 and check whether ||V1-V0|| (the largest absolute difference) is below a tolerance. If it is, I stop; otherwise I set V0=V1 and repeat.
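For reference, the update in step 2 can be written as the Bellman operator below. This is reconstructed from the MWE rather than quoted from Bertsekas: u is the CRRA utility fun_u, P(z,z') is entry (z,z') of the transition matrix PZ (so that EV = V0*PZ' is the inner sum), and the payoff is minus infinity when consumption is not positive. The stopping rule is the sup norm, which is what max(abs(V0(:)-V1(:))) computes.

$$
\begin{aligned}
R(a',a,z) &= \begin{cases} u\big((1+r)\,a + z - a'\big) & \text{if } (1+r)\,a + z - a' > 0,\\ -\infty & \text{otherwise,} \end{cases}\\
V_1(a,z) &= \max_{a'} \Big\{ R(a',a,z) + \beta \sum_{z'} P(z,z')\, V_0(a',z') \Big\},\\
\text{stop if } &\ \max_{a,z} \big|V_1(a,z) - V_0(a,z)\big| \le \text{tol}.
\end{aligned}
$$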
I profiled the code (an MWE is posted below), and the two most time-consuming lines are:
(1) RHS = Ret+beta*permute(EV,[1,3,2]);
(2) [max_val,max_ind] = max(RHS,[],1);
Line (1) takes 63% of the total running time and line (2) takes 32%. As you can see, I have already vectorized all loops.
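To see why lines (1) and (2) dominate, it helps to count the size of RHS. The arithmetic below is a sketch assuming double precision (8 bytes per element) and the two income states of the MWE; n_a = 10000 is the grid size mentioned for the actual code.

```matlab
% Rough size of the temporary RHS built on line (1) in every iteration.
% Assumptions: double precision (8 bytes per element), n_z = 2 as in the MWE.
n_a = 10000;                          % grid size mentioned for the actual code
n_z = 2;                              % income states in the MWE
bytes_per_double = 8;
numel_RHS = n_a * n_a * n_z;          % RHS (and Ret) are n_a x n_a x n_z
bytes_RHS = numel_RHS * bytes_per_double;
fprintf('RHS: %d elements, about %.1f GB\n', numel_RHS, bytes_RHS / 1e9);
% prints: RHS: 200000000 elements, about 1.6 GB
```

Ret has the same number of elements, so every iteration reads Ret and writes RHS on line (1), then reads RHS again on line (2). With this much data and only one addition and one comparison per element, the loop is probably limited by memory bandwidth rather than by arithmetic.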
I would be very grateful for any suggestions. I post an MWE below. (Note that I deliberately set n_a, the number of grid points, to a low value so that interested users can run the example quickly. In my actual code, n_a=10000 or more.)
%% Solve income fluctuation problem CPU
clear;clc;close all
%% Economic parameters
sigma = 2;
r = 0.03;
beta = 0.96;
PZ = [0.60 0.40;
0.05 0.95];
z_grid = [0.5 1.0]';
n_z = length(z_grid);
b = 0; %borrowing limit: asset holdings satisfy a >= -b
grid_max = 4;
n_a = 500; % IN PRACTICE THIS IS EQUAL TO 5000-10000
R = 1+r;
a_grid = linspace(-b,grid_max,n_a)';
if sigma==1
    fun_u = @(c) log(c);
else
    fun_u = @(c) c.^(1-sigma)/(1-sigma);
end
%% Computational parameters
verbose = 0;
tiny = 1e-8; %very small positive number
tol = 1e-6; %tolerance for VFI and TI
max_iter = 500; %maximum num. of iterations for both VFI and TI
%% Start timing
tic
%% STEP 1- Precompute current payoff array R(a',a,z)
a_tomorrow = a_grid; %(a',1,1)
a_today = a_grid'; %(1,a,1)
z_today = shiftdim(z_grid,-2); %(1,1,z)
cons = (1+r)*a_today+z_today-a_tomorrow;
Ret = fun_u(cons); %size: [n_a,n_a,n_z]
Ret(cons<=0) = -inf;
%% STEP 2 - Value function iteration
iter = 1;
err = tol+1;
V0 = zeros(n_a,n_z);
while err>tol && iter<=max_iter
    EV = V0*PZ'; %(a',z): expected continuation value given current z
    RHS = Ret+beta*permute(EV,[1,3,2]); %(a',a,z)
    [max_val,max_ind] = max(RHS,[],1); %maximize over a' (dimension 1)
    V1 = squeeze(max_val);
    pol_ind_ap = squeeze(max_ind);
    err = max(abs(V0(:)-V1(:)));
    if verbose==1
        fprintf('iter = %d, err = %f \n',iter,err)
    end
    iter = iter+1;
    V0 = V1;
end
if err>tol
    error('VFI did not converge!')
else
    %iter was incremented once more after the last update
    fprintf('VFI converged after %d iterations \n',iter-1)
end
pol_ap = a_grid(pol_ind_ap);
pol_c = (1+r)*a_grid+z_grid'-pol_ap;
%% End timing
toc
%% Figures
figure
plot(a_grid,pol_c(:,1),'linewidth',2)
hold on
plot(a_grid,pol_c(:,2),'linewidth',2)
legend('Low shock','High shock','Location','NorthWest')
xlabel('asset level')
ylabel('consumption')
title('Consumption Policy Function')
figure
plot(a_grid,a_grid,'--','linewidth',2)
hold on
plot(a_grid,pol_ap(:,1),'linewidth',2)
plot(a_grid,pol_ap(:,2),'linewidth',2)
legend('45 line','Low shock','High shock','Location','NorthWest')
xlabel('Current period assets')
ylabel('Next-period assets')
title('Assets Policy Function')
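A variant of STEP 2 that nobody in the thread tested: loop over the n_z income states and take the max over one 2-D n_a x n_a slice at a time, so the largest temporary is a single slice instead of the whole 3-D RHS. This is a hedged sketch on a small hypothetical grid (n_a = 200, and an arbitrary V0 chosen only to make the check non-trivial); it verifies that one update matches the vectorized version, but whether it is faster for n_a = 10000 would have to be measured.

```matlab
% Sketch (not from the thread): one Bellman update done slice by slice over z.
% Hypothetical small grid so it runs instantly; parameters mirror the MWE.
sigma = 2; r = 0.03; beta = 0.96;
PZ = [0.60 0.40; 0.05 0.95];
z_grid = [0.5 1.0]';
n_z = numel(z_grid);
n_a = 200;
a_grid = linspace(0, 4, n_a)';

% Payoff array Ret(a',a,z), built as in STEP 1 of the MWE
cons = (1+r)*a_grid' + shiftdim(z_grid,-2) - a_grid;
Ret = cons.^(1-sigma)/(1-sigma);
Ret(cons<=0) = -inf;

% Arbitrary (hypothetical) current guess for V, just to make the check non-trivial
V0 = log(1 + a_grid) * [1 2];
EV = V0*PZ';                         % EV(a',z) = sum over z' of PZ(z,z')*V0(a',z')

% (a) Vectorized update: one n_a x n_a x n_z temporary
V_vec = reshape(max(Ret + beta*reshape(EV,n_a,1,n_z), [], 1), n_a, n_z);

% (b) Loop over z: each temporary is a single n_a x n_a slice
V_loop = zeros(n_a, n_z);
pol_loop = zeros(n_a, n_z);
for iz = 1:n_z
    % beta*EV(:,iz) is a column indexed by a' and expands across the a columns
    [v, ind] = max(Ret(:,:,iz) + beta*EV(:,iz), [], 1);
    V_loop(:,iz) = v';
    pol_loop(:,iz) = ind';
end

fprintf('max |V_vec - V_loop| = %g\n', max(abs(V_vec(:) - V_loop(:))));
```

Both versions perform the same elementwise additions and the same max, so the values agree exactly; the loop only changes how much memory is live at once.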
Torsten on 14 Sep 2024 (edited on 14 Sep 2024)
You want to speed up a runtime of 0.17 s?
Apart from this, I don't think there is much to optimize in the two commands you listed.
Alessandro on 14 Sep 2024
The running time of the code depends on the number of grid points for assets, n_a. In this example I set n_a=500, but I need something like n_a=10000. I stated this explicitly in a comment on the line where n_a is set.
Moreover, the problem shown here is part of a larger project and has to be run hundreds of times, so even if it took only 0.17 s, it would still be worth speeding up.


Accepted Answer

Matt J on 14 Sep 2024 (edited on 14 Sep 2024)
This might be a little faster.
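% Replaces the while loop of STEP 2; iter, err and V0 are initialized as in the MWE
betaPZtransp=beta*PZ'; % fold beta into the transition matrix once, outside the loop
tic
while err>tol && iter<=max_iter
    RHS = Ret + reshape(V0*betaPZtransp,n_a,1,n_z); % reshape instead of permute
    V1 = max(RHS,[],1); % value only: the argmax is not needed inside the loop
    err = norm( V0(:)-V1(:) ,inf);
    if verbose
        fprintf('iter = %d, err = %f \n',iter,err)
    end
    iter = iter+1;
    V0 = reshape(V1,n_a,n_z); % reshape instead of squeeze
end
toc
[V1,pol_ind_ap]=max(RHS,[],1); % argmax computed once, after convergence
pol_ind_ap = reshape(pol_ind_ap, n_a,n_z);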
Matt J on 14 Sep 2024 (edited on 14 Sep 2024)
You should also use gpuArrays if you have appropriate GPU hardware and toolboxes.
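A hedged sketch of the GPU suggestion, applied to the accepted answer's loop. It assumes Parallel Computing Toolbox and a supported GPU for gpuArray and gather, and uses canUseGPU (MATLAB R2019b or later) to decide; if that function is missing or returns false, the same loop simply runs on the CPU. The grid is a small hypothetical one. For n_a = 10000, Ret and RHS are each about 1.6 GB in double precision, so whether they fit in GPU memory, and whether the GPU is faster at all, has to be checked on the actual hardware.

```matlab
% Sketch: run the VFI loop on a GPU when one is usable, else on the CPU.
% canUseGPU exists in MATLAB R2019b+; try/catch covers environments without it.
try
    useGPU = canUseGPU();
catch
    useGPU = false;
end

% Small hypothetical problem with the MWE's parameters
sigma = 2; r = 0.03; beta = 0.96; tol = 1e-6; max_iter = 500;
PZ = [0.60 0.40; 0.05 0.95];
z_grid = [0.5 1.0]';
n_z = numel(z_grid);
n_a = 300;
a_grid = linspace(0, 4, n_a)';
cons = (1+r)*a_grid' + shiftdim(z_grid,-2) - a_grid;
Ret = cons.^(1-sigma)/(1-sigma);
Ret(cons<=0) = -inf;

V0 = zeros(n_a, n_z);
betaPZtransp = beta*PZ';
if useGPU
    % Move the large payoff array and the value function to the device once
    Ret = gpuArray(Ret);
    V0 = gpuArray(V0);
end

iter = 0;
err = tol + 1;
while err > tol && iter < max_iter
    RHS = Ret + reshape(V0*betaPZtransp, n_a, 1, n_z);
    V1 = reshape(max(RHS, [], 1), n_a, n_z);
    err = max(abs(V0(:) - V1(:)));
    if useGPU
        err = gather(err);           % only a scalar comes back each iteration
    end
    V0 = V1;
    iter = iter + 1;
end
[~, pol_ind_ap] = max(RHS, [], 1);   % argmax once, after the loop
pol_ind_ap = reshape(pol_ind_ap, n_a, n_z);
if useGPU
    V0 = gather(V0);
    pol_ind_ap = gather(pol_ind_ap);
end
fprintf('GPU used: %d, iterations: %d, err: %g\n', useGPU, iter, err);
```

The large arrays stay on the device for the whole loop; only the scalar err is gathered per iteration, and the results are gathered once at the end.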
Alessandro on 14 Sep 2024 (edited on 14 Sep 2024)
Thanks! Your version is indeed a bit faster. After some tests, what really helped was moving the computation of the argmax (i.e. pol_ind_ap) out of the while loop and eliminating the call to squeeze. By contrast, replacing permute with reshape does not seem to affect the timing (maybe because both permute and reshape are built-in functions, while squeeze is an M-file).
I will try the gpu later.
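The observation about permute and reshape is consistent with what the two calls produce: for EV of size n_a x n_z, permute(EV,[1,3,2]) only inserts a singleton dimension, so it yields the same values in the same memory order as reshape(EV,n_a,1,n_z). The check below uses a tiny hypothetical EV; it confirms the results are identical but says nothing about speed, which would need to be measured (e.g. with timeit) at the real sizes.

```matlab
% Compare the two ways of turning EV (n_a x n_z) into an n_a x 1 x n_z array.
n_a = 5; n_z = 2;
EV = reshape(1:n_a*n_z, n_a, n_z);   % hypothetical stand-in for V0*PZ'
A = permute(EV, [1 3 2]);            % as in the original line (1)
B = reshape(EV, n_a, 1, n_z);        % as in the accepted answer
fprintf('size(A) = %s, isequal(A,B) = %d\n', mat2str(size(A)), isequal(A, B));

% Likewise, squeeze and reshape agree on the 1 x n_a x n_z output of max(RHS,[],1)
M = reshape(1:n_a*n_z, 1, n_a, n_z);
fprintf('squeeze vs reshape equal: %d\n', isequal(squeeze(M), reshape(M, n_a, n_z)));
```

One robustness difference: squeeze leaves 2-D arrays unchanged, so with a single income state (n_z = 1) squeeze(max_val) would stay a 1 x n_a row, whereas reshape(...,n_a,n_z) always returns an n_a x n_z array.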



Release

R2024a
