How to multiply N matrices without a FOR loop? (Slices of 3D array)
I have a 2x2xN 3D array which, for my purposes, is essentially N 2x2 matrices, and I want to multiply all of them together so that I get the following result:
N = 14;
M = rand(2,2,N);
Z = M(:,:,1)*M(:,:,2)* ... *M(:,:,N);
size(Z) == [2 2]
I can do it with a for loop, but I am looking for a single-line approach, something like:
prod(M,3);
but using mtimes, i.e. matrix multiplication along the 3rd dimension (not the element-wise product).
I also converted M into an Nx1 cell array of 2x2 matrices, but that did not help with the multiplication either.
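For reference, the requested product is a plain left-to-right accumulation. A minimal sketch in C (the language of the MEX answer below), assuming the 2x2xN array is stored contiguously in MATLAB's column-major order, so slice k occupies four consecutive doubles; the names `mul2x2` and `mprod2x2` are hypothetical:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* C = A*B for 2x2 matrices stored column-major:
   a[0] = a11, a[1] = a21, a[2] = a12, a[3] = a22. c must not alias a or b. */
static void mul2x2(const double *a, const double *b, double *c)
{
    c[0] = a[0] * b[0] + a[2] * b[1];
    c[1] = a[1] * b[0] + a[3] * b[1];
    c[2] = a[0] * b[2] + a[2] * b[3];
    c[3] = a[1] * b[2] + a[3] * b[3];
}

/* z = M(:,:,1) * M(:,:,2) * ... * M(:,:,n) for a 2x2xn array m
   (slice k, 0-based, occupies m[4*k .. 4*k+3]). Requires n >= 1. */
void mprod2x2(const double *m, size_t n, double z[4])
{
    double tmp[4];
    assert(n >= 1);
    memcpy(z, m, 4 * sizeof *z);            /* Z = M(:,:,1) */
    for (size_t k = 1; k < n; ++k) {
        mul2x2(z, m + 4 * k, tmp);          /* tmp = Z * next slice */
        memcpy(z, tmp, 4 * sizeof *z);
    }
}
```

With M(:,:,1) = [1 2; 3 4] and M(:,:,2) = [0 1; 1 0], i.e. m = {1, 3, 2, 4, 0, 1, 1, 0}, the result z = {2, 4, 1, 3} is [2 1; 4 3]. Swapping the two slices gives a different result, which is why the order of the slices matters.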
8 Comments
Jan, 7 Dec 2017 (edited by Jan, 7 Dec 2017)
Stephen's comment is very good.
For estimating the effect of optimizing the code, the typical input sizes matter: is it really a [2 x 2 x N] array, and what values of N do you have? For larger rows and columns, the main work is done by mtimes, and the loop does not matter much. mtimes calls optimized BLAS or ATLAS functions, so there is no room for further improvement. But I do not know whether these library functions handle tiny 2x2 matrices with unrolled loops, so perhaps a C-Mex function could be more efficient.
Answers (5)
Jan, 7 Dec 2017 (edited by Jan, 7 Dec 2017)
If you really have 2x2 submatrices to accumulate, try a C-Mex function:
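#include "mex.h"

/* Y = CumMProd2x2(X) returns X(:,:,1) * X(:,:,2) * ... * X(:,:,N)
   for a real double [2 x 2 x N] array X.
   Compile in MATLAB with:  mex CumMProd2x2.c */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  const mwSize *size;
  mwSize ndims, N;
  double *p, *q, q11, q12, q21, q22, t11, t21;

  if (nrhs != 1 || !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0])) {
    mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
                      "1st input must be a real double [2 x 2 x N] array.");
  }
  ndims = mxGetNumberOfDimensions(prhs[0]);
  size = mxGetDimensions(prhs[0]);
  if (ndims > 3 || size[0] != 2 || size[1] != 2) {
    mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
                      "1st input must be a real double [2 x 2 x N] array.");
  }
  /* A plain 2x2 matrix has only 2 dimensions: size[2] must not be read then. */
  N = (ndims == 3) ? size[2] : 1;
  if (N == 0) {  /* otherwise --N below would wrap around (mwSize is unsigned) */
    mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
                      "1st input must not be empty.");
  }
  p = mxGetPr(prhs[0]);
  q11 = p[0];  /* Q = X(:,:,1), column-major order */
  q21 = p[1];
  q12 = p[2];
  q22 = p[3];
  while (--N) {  /* Unrolled 2x2 matrix multiplication: Q = Q * next slice */
    p += 4;      /* advance to the next 2x2 slice */
    t11 = q11 * p[0] + q12 * p[1];
    t21 = q21 * p[0] + q22 * p[1];
    q12 = q11 * p[2] + q12 * p[3];  /* still uses the old q11, q12 */
    q22 = q21 * p[2] + q22 * p[3];  /* still uses the old q21, q22 */
    q11 = t11;
    q21 = t21;
  }
  plhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
  q = mxGetPr(plhs[0]);
  q[0] = q11;
  q[1] = q21;
  q[2] = q12;
  q[3] = q22;
}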
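To check the unrolled arithmetic outside MATLAB, the loop can be lifted out of mexFunction into a plain C function with no MEX API calls. A sketch; `prod2x2_unrolled` is a hypothetical name, and the input is assumed to hold n >= 1 contiguous column-major 2x2 slices:

```c
#include <assert.h>
#include <stddef.h>

/* Same unrolled update as in the MEX file, as a plain C function:
   p points to n >= 1 column-major 2x2 slices, z receives their ordered product. */
void prod2x2_unrolled(const double *p, size_t n, double z[4])
{
    double q11, q21, q12, q22, t11, t21;
    assert(n >= 1);
    q11 = p[0]; q21 = p[1]; q12 = p[2]; q22 = p[3];   /* Q = first slice */
    for (size_t k = 1; k < n; ++k) {
        p += 4;                                       /* next slice */
        t11 = q11 * p[0] + q12 * p[1];
        t21 = q21 * p[0] + q22 * p[1];
        q12 = q11 * p[2] + q12 * p[3];                /* uses the old q11, q12 */
        q22 = q21 * p[2] + q22 * p[3];                /* uses the old q21, q22 */
        q11 = t11;
        q21 = t21;
    }
    z[0] = q11; z[1] = q21; z[2] = q12; z[3] = q22;
}
```

Feeding it [1 2; 3 4] followed twice by [0 1; 1 0] returns [1 2; 3 4] again, since [0 1; 1 0] squared is the identity.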
[EDITED] This is now tested. The speed is very interesting:
function speed
x = rand(2, 2, 1000);
tic; for k = 1:1000, y = CumMProd2x2(x); end; toc
tic; for k = 1:1000, y = CumMProd2x2_AB(x); end; toc
tic
for k = 1:1000 % Jos (10584)
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf, x, size(x, 3));
end
toc
end
function out = CumMProd2x2_AB(M) % Andrei Bobrov
s = size(M, 3);
out = M(:,:,1);
for ii = 2:s
out = out * M(:,:,ii);
end
end
Timings with R2016b, 64-bit, Windows 7:
Elapsed time is 0.011403 seconds. C-mex
Elapsed time is 3.884977 seconds. Loop
Elapsed time is 96.038754 seconds. Recursive anonymous function
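I was surprised that Andrei's loop is so slow, although it is clearly the nicest and cleanest solution. Let's try to unroll the loop as in the C code:
function out = CumMProd2x2_unroll(M)
% Product M(:,:,1)*M(:,:,2)*...*M(:,:,end) of a 2x2xN array,
% with the 2x2 matrix multiplication written out by hand.
q11 = M(1); % Q = M(:,:,1) via column-major linear indexing
q21 = M(2);
q12 = M(3);
q22 = M(4);
c = 1;
for ii = 2:size(M, 3)
  c = c + 4; % M(c:c+3) is the slice M(:,:,ii)
  t11 = q11 * M(c) + q12 * M(c+1);
  t21 = q21 * M(c) + q22 * M(c+1);
  q12 = q11 * M(c+2) + q12 * M(c+3); % still uses the old q11, q12
  q22 = q21 * M(c+2) + q22 * M(c+3); % still uses the old q21, q22
  q11 = t11;
  q21 = t21;
end
out = [q11, q12; q21, q22];
end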
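The unrolled M-code relies on MATLAB's column-major linear indexing: with c = 4*ii - 3, the elements M(c), M(c+1), M(c+2), M(c+3) are M(1,1,ii), M(2,1,ii), M(1,2,ii), M(2,2,ii). The same mapping as a C helper, a sketch with a hypothetical name; it is 0-based, so MATLAB's linear index is the returned offset plus 1:

```c
#include <assert.h>
#include <stddef.h>

/* 0-based offset of element (i, j, k) in a rows x cols x pages array stored
   in column-major order (first index varies fastest, as in MATLAB). */
size_t colmajor_offset(size_t i, size_t j, size_t k, size_t rows, size_t cols)
{
    assert(i < rows && j < cols);
    return i + rows * (j + cols * k);
}
```

For a 2x2xN array this is i + 2*j + 4*k, which is why the unrolled loop can step through the slices by adding 4 to c.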
This is about 64 times faster than the direct approach "out * M(:,:,ii)":
Elapsed time is 0.061287 seconds. Unrolled
Apparently MATLAB calls smart, highly optimized libraries for the matrix multiplication, which treat the tiny input with the same heavy machinery as a 1000x1000 matrix.
But this unrolled version is so ugly that I would hesitate to use it in production code. For x = rand(2, 2, 100000) I get these timings for 1000 iterations:
Elapsed time is 1.377695 seconds. C-mex
Elapsed time is 2.872356 seconds. M with unrolled mtimes
Only a factor of 2! Another example that loops in MATLAB are not that bad compared to C.
2 Comments
Jos (10584), 7 Dec 2017
Haha, I really liked my anonymous-function approach, and I did expect it to perform poorly, but not that poorly... haha
Andrei Bobrov, 6 Dec 2017
s = size(M, 3); % number of slices (size(M) has no 3rd element when N = 1)
out = M(:,:,1);
for ii = 2:s
  out = out*M(:,:,ii); % accumulate the product left to right
end
5 Comments
Jan, 7 Dec 2017
+1: This is the nicest solution. That multiplying 2x2 matrices is much faster with a hard-coded algorithm is not a flaw of this solution.
Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays.
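For comparison, a plain C loop for square m x m x n arrays is short; what it gives up is the hard-coded 2x2 unrolling that makes the MEX file fast, so no speed-up over mtimes should be assumed from it. A sketch with hypothetical names, assuming column-major storage and a caller-supplied scratch buffer of m*m doubles:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* C = A*B for m x m column-major matrices; c must not alias a or b. */
static void matmul_sq(const double *a, const double *b, double *c, size_t m)
{
    for (size_t j = 0; j < m; ++j)
        for (size_t i = 0; i < m; ++i) {
            double s = 0.0;
            for (size_t l = 0; l < m; ++l)
                s += a[i + m * l] * b[l + m * j];   /* A(i,l) * B(l,j) */
            c[i + m * j] = s;
        }
}

/* z = M(:,:,1) * ... * M(:,:,n) for an m x m x n array; tmp holds m*m doubles. */
void mprod_sq(const double *mats, size_t m, size_t n, double *z, double *tmp)
{
    assert(m >= 1 && n >= 1);
    memcpy(z, mats, m * m * sizeof *z);             /* Z = first slice */
    for (size_t k = 1; k < n; ++k) {
        matmul_sq(z, mats + m * m * k, tmp, m);     /* tmp = Z * slice k */
        memcpy(z, tmp, m * m * sizeof *z);
    }
}
```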
Matt J, 7 Dec 2017
Quoting Jan: "Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays."
I just wanted to note that, while my solution based on MTIMESX is not as fast as Jan's for the 2x2xN case, it is applicable to arbitrary MxMxN arrays.
Matt J, 6 Dec 2017
The following is not a one-line solution (for that, just put it in a function file) and requires MTIMESX from the File Exchange. However, I do see a speed-up of a few times over a conventional for-loop:
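% Pairwise (tree) reduction: multiply neighbouring slices until one remains
out=M;
while size(out,3)>1
  n=size(out,3);
  if mod(n,2)
    % Odd number of slices: pair up the first n-1 and carry the last one over
    n=n-1;
    A=out(:,:,1:2:n);
    B=out(:,:,2:2:n);
    out=cat(3,mtimesx(A,B),out(:,:,n+1));
  else
    A=out(:,:,1:2:n);
    B=out(:,:,2:2:n);
    out=mtimesx(A,B); % page-wise products A(:,:,k)*B(:,:,k)
  end
end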
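The loop above is a pairwise (tree) reduction: each pass multiplies slices 1 and 2, 3 and 4, and so on, carrying an odd last slice over unchanged, so the number of slices roughly halves per pass while the left-to-right order is preserved. A sketch of the same scheme in C for 2x2 slices, with hypothetical function names; because the products are grouped differently, the result can differ from the sequential loop in the last bits:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* C = A*B for column-major 2x2 matrices; c must not alias a or b. */
static void mul2x2(const double *a, const double *b, double *c)
{
    c[0] = a[0] * b[0] + a[2] * b[1];
    c[1] = a[1] * b[0] + a[3] * b[1];
    c[2] = a[0] * b[2] + a[2] * b[3];
    c[3] = a[1] * b[2] + a[3] * b[3];
}

/* In-place tree reduction of n >= 1 column-major 2x2 slices in buf.
   The ordered product ends up in buf[0..3]; the rest of buf is overwritten. */
void mprod2x2_tree(double *buf, size_t n)
{
    assert(n >= 1);
    while (n > 1) {
        size_t pairs = n / 2;
        for (size_t i = 0; i < pairs; ++i) {
            double c[4];
            mul2x2(buf + 8 * i, buf + 8 * i + 4, c);  /* slices 2i and 2i+1 */
            memcpy(buf + 4 * i, c, sizeof c);         /* store as slice i */
        }
        if (n % 2)                                    /* carry the odd last slice */
            memcpy(buf + 4 * pairs, buf + 4 * (n - 1), 4 * sizeof *buf);
        n = pairs + n % 2;
    }
}
```

Writing slice i only after reading slices 2i and 2i+1 is safe in place, because i never exceeds 2i and later pairs are read from higher slice indices.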
5 Comments
James Tursa, 7 Dec 2017 (edited by James Tursa, 7 Dec 2017)
Side note: by default, MTIMESX calls BLAS library routines for the matrix multiply, so that it matches the result of the MATLAB for-loop m-code, whereas MTIMESX with the 'SPEED' option uses hand-coded inline matrix multiply code for slices up to 5x5, which may not match the MATLAB for-loop m-code result exactly.
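Why the results can differ at all: floating-point addition is not associative, so two kernels that sum the same products in a different order (a BLAS routine versus hand-coded inline code) may disagree in the last bits. A minimal illustration in C with hypothetical function names, assuming plain IEEE double arithmetic without extended-precision intermediates:

```c
#include <assert.h>

/* Same three terms, two groupings. In exact arithmetic both are equal. */
double sum_left(double a, double b, double c)  { return (a + b) + c; }
double sum_right(double a, double b, double c) { return a + (b + c); }
```

With a = 1, b = 1e16, c = -1e16, the left grouping returns 0 because 1e16 + 1 rounds back to 1e16, while the right grouping returns 1. Real matrix kernels differ far less dramatically, but by the same mechanism.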
Some time back I had a beta version of MTIMESX that implemented the matrix-equivalent versions of 'prod' and 'cumprod'. Maybe it is time I dust it off and finish the implementation and testing so I can publish it.
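A matrix version of cumprod along the third dimension would return every partial product, out(:,:,k) = M(:,:,1)*...*M(:,:,k), not only the last one. A sketch in C of what that computes; it is not James Tursa's unpublished MTIMESX code, the names are hypothetical, and storage is assumed column-major:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* C = A*B for column-major 2x2 matrices; c must not alias a or b. */
static void mul2x2(const double *a, const double *b, double *c)
{
    c[0] = a[0] * b[0] + a[2] * b[1];
    c[1] = a[1] * b[0] + a[3] * b[1];
    c[2] = a[0] * b[2] + a[2] * b[3];
    c[3] = a[1] * b[2] + a[3] * b[3];
}

/* out(:,:,k) = M(:,:,1) * ... * M(:,:,k) for k = 1..n; out must not alias m. */
void cummprod2x2(const double *m, size_t n, double *out)
{
    assert(out != m);
    if (n == 0)
        return;
    memcpy(out, m, 4 * sizeof *out);                   /* first partial product */
    for (size_t k = 1; k < n; ++k)
        mul2x2(out + 4 * (k - 1), m + 4 * k, out + 4 * k);  /* previous * slice k */
}
```

The last slice of out is the same product the other answers compute; each earlier slice costs nothing extra, since it is needed anyway as an intermediate.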
Matt J, 7 Dec 2017
That is strange, since I still see a significant speed-up even with
mtimesx MATLAB
Steven Lord, 17 Sep 2020
If you're using release R2020b or later, take a look at the pagemtimes function introduced in that release.
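pagemtimes(X,Y) multiplies corresponding pages, Z(:,:,k) = X(:,:,k)*Y(:,:,k), which is the page-wise product that Matt J's reduction loop takes from mtimesx; on R2020b or later that loop can presumably call pagemtimes(A,B) instead. The page-wise operation itself, sketched in C for 2x2 pages with hypothetical names and column-major storage:

```c
#include <assert.h>
#include <stddef.h>

/* C = A*B for column-major 2x2 matrices; c must not alias a or b. */
static void mul2x2(const double *a, const double *b, double *c)
{
    c[0] = a[0] * b[0] + a[2] * b[1];
    c[1] = a[1] * b[0] + a[3] * b[1];
    c[2] = a[0] * b[2] + a[2] * b[3];
    c[3] = a[1] * b[2] + a[3] * b[3];
}

/* c(:,:,k) = a(:,:,k) * b(:,:,k) for each of the given number of 2x2 pages. */
void pagemul2x2(const double *a, const double *b, double *c, size_t pages)
{
    assert(c != a && c != b);
    for (size_t k = 0; k < pages; ++k)
        mul2x2(a + 4 * k, b + 4 * k, c + 4 * k);
}
```

Note that this alone does not chain the pages together; a reduction such as the pairwise loop above is still needed to get the single product.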
Jos (10584), 6 Dec 2017
Here is one that uses recursion instead of a for-loop; it is not faster, and somewhat mysterious, but just nice :) ...
M = randi(5,[2 2 4]) ; % data
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf,M,size(M,3)) % voila, it works!
3 Comments
Jos (10584), 7 Dec 2017
It is the inline version of this recursive m-file:
function X = mprod(M,n)
% X = mprod(M) returns M(:,:,1) * M(:,:,2) * ... * M(:,:,end)
% where M is a 3D array
if nargin==1
X = mprod(M,size(M,3)) ;
elseif n < 2
X = M(:,:,1) ;
else
X = mprod(M,n-1) * M(:,:,n) ;
end
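The same recursion in C, for comparison with the loop versions. A sketch with hypothetical names, assuming column-major 2x2 slices; the recursion depth grows linearly with n, so for long stacks of slices the iterative loop is the safer choice:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* C = A*B for column-major 2x2 matrices; c must not alias a or b. */
static void mul2x2(const double *a, const double *b, double *c)
{
    c[0] = a[0] * b[0] + a[2] * b[1];
    c[1] = a[1] * b[0] + a[3] * b[1];
    c[2] = a[0] * b[2] + a[2] * b[3];
    c[3] = a[1] * b[2] + a[3] * b[3];
}

/* Mirrors mprod: mprod(M, 1) = M(:,:,1), mprod(M, n) = mprod(M, n-1) * M(:,:,n). */
void mprod2x2_rec(const double *m, size_t n, double x[4])
{
    double prev[4];
    assert(n >= 1);
    if (n == 1) {                        /* base case: first slice */
        memcpy(x, m, 4 * sizeof *x);
        return;
    }
    mprod2x2_rec(m, n - 1, prev);        /* product of the first n-1 slices */
    mul2x2(prev, m + 4 * (n - 1), x);    /* times slice n (0-based index n-1) */
}
```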