Eliminate for-loop in recursive computation?

I'm trying to speed up a recursive calculation that currently uses a for-loop. A minimal working example is below. There are two computations in the recursion: "pn0" and "fact". I have figured out how to compute fact outside the loop, but pn0 is giving me trouble. Any insight into how to remove the loop and/or speed this up would be greatly appreciated! I have tried building a separate function and using arrayfun, but arrayfun still runs a loop under the hood, so I don't think it would be faster.
```matlab
% define some constants
m = 0:10;
y = cos(pi/4);
% compute fact without a loop
fact = cumsum([1 ones(1,size(m,2)).*2]);
fact = fact(1:end-1);
% Generate the correct pn0 array
pn0 = 1; % initial value
factdum = 1; % initial value
pn_out = [];
fact_out = [];
for mm = 0:10
    pn0 = -pn0*factdum*y; % do the computation <-- this is the computation I'm trying to pull out of the loop
    pn_out = [pn_out,pn0]; % store the pn0 array output
    fact_out = [fact_out,factdum]; % store the fact array output
    factdum = factdum + 2; % compute fact inside the loop to make sure it was done correctly outside the loop
end
```
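Unrolling the recurrence shows why a loop-free form should exist. Writing $f_m = 2m+1$ for the value of `factdum` at step `mm = m` and $p_m$ for the stored `pn0` (symbols introduced here only for illustration), with the initial value $p_{-1} = 1$, the loop computes:

$$
p_m = -y\,f_m\,p_{m-1}
\quad\Longrightarrow\quad
p_m = \prod_{j=0}^{m} \left(-y\,f_j\right) = (-y)^{m+1} \prod_{j=0}^{m} (2j+1) = (-y)^{m+1}\,(2m+1)!!
$$

So each entry is a double factorial times a power of $-y$, and the running product $\prod_{j=0}^{m} f_j$ for all $m$ at once is exactly what `cumprod(fact)` returns.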

Accepted Answer

Torsten on 28 Jan 2025
Edited: Torsten on 28 Jan 2025
```matlab
pn_out = cumprod(fact).*(-y).^(m+1)
```
But I'm not sure this will be faster than your loop, especially since computing the powers y.^(1:n) directly may cost more than building them up recursively.
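The concern about the powers can be sidestepped by folding $-y$ into each factor, so that a single `cumprod` builds the powers along with the double factorial. This is a minimal sketch of that rearrangement (a variant not proposed in the thread), assuming the same `m` and `y` as in the question; the names `pn_pow`, `pn_single` and `relErr` are illustrative. It also uses `fact = 2*m + 1`, which gives the same odd numbers 1, 3, ..., 21 as the `cumsum` construction. Whether it is actually faster for 11 elements is not established here.

```matlab
% Same constants as in the question
m = 0:10;
y = cos(pi/4);

% Odd factors 1, 3, ..., 21 (same values as the cumsum construction)
fact = 2*m + 1;

% Torsten's closed form: double factorial times a power of -y
pn_pow = cumprod(fact) .* (-y).^(m+1);

% Variant: multiply each factor by -y first, so one cumprod also builds the powers
pn_single = cumprod(-y * fact);

% The two forms should agree up to round-off
relErr = max(abs(pn_pow - pn_single) ./ abs(pn_single))
```

A side effect of this ordering, when |y| < 1, is that the bare double factorial is never formed, so for long vectors it overflows to `Inf` later than `cumprod(fact)` alone would.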
Christopher Smith on 29 Jan 2025
Thank you for that clever solution! To time both versions, I wrapped each one in a function for use with timeit, and according to timeit the cumprod version you suggested is actually faster. Timing with tic/toc, though, suggests that the cumprod version is slower. See the code below.
```matlab
% define some constants
m = 0:10;
y = cos(pi/4);
fact = cumsum([1 ones(1,size(m,2)).*2]);
fact = fact(1:end-1);
% cumprod version by Torsten
pn = cumprod(fact).*(-y).^(m+1);
% Generate the correct pn array
pn0 = 1;
factdum = 1;
pn_out = [];
fact_out = [];
for mm = 0:10
    pn0 = -pn0*factdum*y;
    pn_out = [pn_out,pn0];
    fact_out = [fact_out,factdum];
    factdum = factdum + 2;
end
% check to make sure no mistakes were made in the conversion to functions
[f1,pn1] = pn_test1(m,y); % cumprod version
[f2,pn2] = pn_test2(m,y); % for-loop version
isequal(f1,f2) % fact vectors should match exactly
max(abs(pn1 - pn2)./abs(pn2)) % relative difference should be at round-off level
% perform the timing
ff1 = @() pn_test1(m,y);
test1_time = timeit(ff1) % cumprod version
% output: test1_time = 1.7310e-05
ff2 = @() pn_test2(m,y);
test2_time = timeit(ff2) % for-loop version
% output: test2_time = 1.3310e-05

% helper function for timing the cumprod version
function [fact,pn] = pn_test1(m,y)
    fact = cumsum([1 ones(1,size(m,2)).*2]);
    fact = fact(1:end-1);
    pn = cumprod(fact).*(-y).^(m+1);
end

% helper function for timing the for-loop version
function [fact,pn] = pn_test2(m,y)
    pn0 = 1;
    factdum = 1;
    pn = [];
    fact = [];
    for mm = m
        pn0 = -pn0*factdum*y;
        pn = [pn,pn0];
        fact = [fact,factdum];
        factdum = factdum + 2;
    end
end
```
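With only 11 elements each call takes microseconds, so a single tic/toc reading can be dominated by timer resolution and one-time overhead, whereas `timeit` calls the function several times and reports the median. The sketch below illustrates a fairer tic/toc comparison under stated assumptions: each version runs `nrep` times (an arbitrary choice) and the average is reported, and the loop output is preallocated with `zeros` instead of being grown with `[pn,pn0]`, which reallocates the array on every iteration. The variable names are illustrative, and no particular winner is implied; results depend on the machine and MATLAB version.

```matlab
% Same constants as in the thread
m = 0:10;
y = cos(pi/4);
nrep = 1e4;   % arbitrary repetition count (assumption)

% cumprod version, repeated nrep times
tic
for r = 1:nrep
    fact = 2*m + 1;
    pnVec = cumprod(fact) .* (-y).^(m+1);
end
tVec = toc / nrep     % average seconds per evaluation

% Preallocated loop version of the same recurrence, repeated nrep times
tic
for r = 1:nrep
    pnLoop = zeros(size(m));   % preallocate instead of growing with [pn,pn0]
    p = 1;
    for k = 1:numel(m)
        % 2*m(k)+1 equals factdum at this step because m starts at 0
        p = -p * (2*m(k) + 1) * y;
        pnLoop(k) = p;
    end
end
tLoop = toc / nrep    % average seconds per evaluation
```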
