"Partial" matrix multiplication

Hieu Pham on 1 Jun 2015
Answered: Joss Knight on 26 Jun 2015
Suppose that I have two matrices, A and B, both of size Dx(3N). I want to multiply each block of 3 consecutive columns of A by the transpose of the corresponding block of 3 consecutive columns of B (each of these products is a DxD matrix). What is the best way to do this?
For example, let's say
A = [a_1, a_2, a_3, b_1, b_2, b_3, c_1, c_2, c_3]
B = [x_1, x_2, x_3, y_1, y_2, y_3, z_1, z_2, z_3]
where a_i, b_i, c_i, x_i, y_i, z_i all have size Dx1. I want to compute
[a_1, a_2, a_3]*[x_1, x_2, x_3]'
[b_1, b_2, b_3]*[y_1, y_2, y_3]'
[c_1, c_2, c_3]*[z_1, z_2, z_3]'
and of course, I need to store the results.
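A minimal sketch of the computation described above, assuming small example sizes (D = 4, N = 3) chosen only for illustration. Each result is stored in a cell array C, where C{i} is the DxD product for the i-th block of 3 columns:

```matlab
% Reference computation: one D x D product per block of 3 columns.
D = 4; N = 3;                 % example sizes (arbitrary, for illustration)
A = rand(D, 3*N);
B = rand(D, 3*N);
C = cell(1, N);               % C{i} holds the i-th D x D result
for i = 1:N
    cols = 3*i-2 : 3*i;       % column indices of the i-th block
    C{i} = A(:, cols) * B(:, cols)';
end
```

This is the straightforward definition; the answers below differ mainly in how they avoid or speed up this loop and in how they lay out the stored results.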

Accepted Answer

Azzi Abdelmalek on 1 Jun 2015
Edited: Azzi Abdelmalek on 1 Jun 2015
A=randi(9,3,9)
B=randi(9,3,9)
idx=1:3:size(A,2)
out=cell2mat(arrayfun(@(x) A(:,x:x+2)*B(:,x:x+2)',idx,'un',0))
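Because cell2mat concatenates the 1xN cell array horizontally, out in this answer is a D x (D*N) matrix with the block products side by side ('un',0 is an abbreviation of 'UniformOutput',false). A short sketch, assuming you would rather have the products as pages of a D x D x N array (the name C is illustrative), reshapes the result; this works because MATLAB stores arrays in column-major order, so the first D columns of out become page 1, and so on:

```matlab
A = randi(9, 3, 9);
B = randi(9, 3, 9);
idx = 1:3:size(A, 2);          % first column of each block: 1, 4, 7
out = cell2mat(arrayfun(@(x) A(:,x:x+2) * B(:,x:x+2)', idx, ...
                        'UniformOutput', false));   % D x (D*N), side by side
D = size(A, 1);
C = reshape(out, D, D, []);    % C(:,:,k) is the k-th D x D block product
```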
Hieu Pham on 2 Jun 2015
Thank you for your answer. However, I just ran a comparison:
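function test(D,n)
% Time both implementations on random D x 3n inputs.
A = rand(D,3*n);
B = rand(D,3*n);
tic; x = test1(A,B); toc
tic; x = test2(A,B); toc
end

function [out] = test1(A,B)
% arrayfun + cell2mat: one anonymous-function call per 3-column block.
idx=1:3:size(A,2);
out=cell2mat(arrayfun(@(x) A(:,x:x+2)*B(:,x:x+2)',idx,'un',0));
end

function [out] = test2(A,B)
% Preallocated loop: write each D x D block product into its slot of out.
D = size(B,1);
n = size(B,2) / 3;
out = zeros(D,D*n);
for i = 1:n
    out(:,D*(i-1)+1:D*i) = A(:,3*i-2:3*i) * B(:,3*i-2:3*i)';
end
end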
It turns out that your method is much slower than a simple for loop:
>> test(1000, 128)
Elapsed time is 5.499530 seconds.
Elapsed time is 1.431573 seconds.
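Timings are only comparable if test1 and test2 produce the same D x (D*n) layout and values. A quick check, written as a standalone script on small, arbitrarily chosen sizes rather than calling the functions above:

```matlab
D = 50; n = 8;                      % small sizes for a quick check (arbitrary)
A = rand(D, 3*n);
B = rand(D, 3*n);

% arrayfun + cell2mat version (same logic as test1)
idx = 1:3:size(A, 2);
out1 = cell2mat(arrayfun(@(x) A(:,x:x+2) * B(:,x:x+2)', idx, ...
                         'UniformOutput', false));

% preallocated loop version (same logic as test2)
out2 = zeros(D, D*n);
for i = 1:n
    out2(:, D*(i-1)+1:D*i) = A(:, 3*i-2:3*i) * B(:, 3*i-2:3*i)';
end

maxDiff = max(abs(out1(:) - out2(:)));   % should be zero or round-off level
```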


More Answers (2)

James Tursa on 2 Jun 2015
Edited: James Tursa on 2 Jun 2015
If you have a C compiler installed, you can use a File Exchange (FEX) submission called mtimesx, which performs N-D matrix multiplication with built-in transpose capability (it does a virtual transpose, not an actual one):
[m,n] = size(A);
n3 = n/3;
Ar = reshape(A,m,3,n3);
Br = reshape(B,m,3,n3);
C = mtimesx(Ar,Br,'t','speedomp');
You can find mtimesx here:
Another option is mmx, but you will have to do the N-D transpose manually with permute:
[m,n] = size(A);
n3 = n/3;
Ar = reshape(A,m,3,n3);
Br = reshape(B,m,3,n3);
C = mmx(Ar,permute(Br,[2 1 3]));
If you don't have a C compiler installed, you can use an M-file-based routine called multiprod:
[m,n] = size(A);
n3 = n/3;
Ar = reshape(A,m,3,n3);
Br = reshape(B,m,3,n3);
C = multiprod(Ar,permute(Br,[2 1 3]));
You can find multiprod here:
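All three suggestions start from the same reshape. A small sketch (sizes are arbitrary) showing that page k of the reshaped array holds columns 3k-2:3k of the original matrix, and that permute(...,[2 1 3]) transposes every page, which is what mmx and multiprod need for the B operand:

```matlab
m = 5; n3 = 4;                   % example sizes (arbitrary)
A = rand(m, 3*n3);
B = rand(m, 3*n3);
Ar = reshape(A, m, 3, n3);       % Ar(:,:,k) equals A(:, 3*k-2:3*k)
Br = reshape(B, m, 3, n3);       % same layout for B
BrT = permute(Br, [2 1 3]);      % 3 x m x n3: each page is transposed
```

No data is reordered by the reshape itself, because column-major storage already places each 3-column block contiguously in memory.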
Hieu Pham on 2 Jun 2015
Edited: Hieu Pham on 2 Jun 2015
Dear James,
Thank you for your answer. I do have a C compiler, so technically I can try mtimesx and mmx. However, in the end I will have to run my code on a GPU. Are you aware of a way to use mtimesx and/or mmx on a GPU, for example a CUDA version?
Also, the last method you proposed (multiprod) seems to be even slower than the one proposed by @Azzi above:
>> test(1000, 128)
Elapsed time is 2.536413 seconds.
Elapsed time is 1.380087 seconds.
Elapsed time is 13.902408 seconds.
The three runtimes correspond to Azzi's method, the for loop, and your method (multiprod), respectively.
James Tursa on 2 Jun 2015
To my knowledge, none of these methods has a CUDA version. For CUDA, you may need to write the code from scratch. If the row size is not too large, hand-coding the individual (m x 3) * (m x 3)' multiplications element by element might be faster than using loops.
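The element-by-element idea rests on writing each (m x 3) * (m x 3)' product as a sum of three outer products, C_p = a_1*b_1' + a_2*b_2' + a_3*b_3'. A CPU sketch of that decomposition in plain MATLAB, vectorized over all blocks with bsxfun (this is not CUDA code, and the sizes are arbitrary). Since bsxfun also accepts gpuArray inputs in Parallel Computing Toolbox, the same pattern may be worth trying on the GPU:

```matlab
m = 4; n3 = 5;                          % example sizes (arbitrary)
A = rand(m, 3*n3);
B = rand(m, 3*n3);
Ar = reshape(A, m, 3, n3);              % page p = p-th block of A
Br = reshape(B, m, 3, n3);              % page p = p-th block of B

% C(:,:,p) = sum over k of Ar(:,k,p) * Br(:,k,p)'
C = zeros(m, m, n3);
for k = 1:3                             % only 3 iterations, independent of n3
    a = reshape(Ar(:, k, :), m, 1, n3); % column k of every block, m x 1 x n3
    b = reshape(Br(:, k, :), 1, m, n3); % same column as a row, 1 x m x n3
    C = C + bsxfun(@times, a, b);       % m x m x n3 outer products
end
```

The temporary created by bsxfun is the same size as the output, so memory use stays proportional to m*m*n3.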



Joss Knight on 26 Jun 2015
If you're running this on a GPU using Parallel Computing Toolbox, as you say, then you can use pagefun:
rows = size(A,1);
assert(size(B,1) == rows);
A = reshape(gpuArray(A), rows, 3, []);
B = reshape(gpuArray(B), rows, 3, []);
Bt = pagefun(@transpose, B);
C = pagefun(@mtimes, A, Bt);
The result C is a rows x rows x (cols/3) ND array.
There is currently no equivalent to pagefun for the CPU, but the CPU will work fine with a loop.
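On the CPU, a loop over pages produces the same rows x rows x (cols/3) layout as the pagefun version. A minimal sketch with arbitrary sizes:

```matlab
rows = 6; nb = 4;                       % example sizes (arbitrary)
A = rand(rows, 3*nb);
B = rand(rows, 3*nb);
Ar = reshape(A, rows, 3, []);           % rows x 3 x nb
Br = reshape(B, rows, 3, []);
C = zeros(rows, rows, size(Ar, 3));     % preallocate the paged result
for p = 1:size(Ar, 3)
    C(:, :, p) = Ar(:, :, p) * Br(:, :, p)';
end
```

In MATLAB R2020b and later, pagemtimes(Ar,'none',Br,'transpose') computes the same paged product on the CPU without an explicit loop.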
