Why is pagemtimes slower than just coding up the matrix multiplication? Especially on GPU.

I'm going to use the pagemtimes function in my custom loss function, but when I train my network on the GPU, it performs poorly. I found other people asking about this in the community, but none of the answers gave me something I could act on. Below are my tests, based on an example from one of those existing questions.
function C=pagemtimes_version(A,B,E,F)
C = pagemtimes(F,(B+pagemtimes(E,A)));
end
function C=direct(A,B,E,F)
C(:,:,1,1) = ...
F(:,:,1,1).*(A(:,:,1,1).*E(:,:,1,1)+B(:,:,1,1)) +...
F(:,:,1,2).*(A(:,:,1,1).*E(:,:,1,2)+B(:,:,1,1)) +...
F(:,:,1,3).*(A(:,:,1,1).*E(:,:,1,3)+B(:,:,1,1));
C(:,:,2,1) = ...
F(:,:,2,1).*(A(:,:,2,1).*E(:,:,2,1)+B(:,:,2,1)) +...
F(:,:,2,2).*(A(:,:,2,1).*E(:,:,2,2)+B(:,:,2,1)) +...
F(:,:,2,3).*(A(:,:,2,1).*E(:,:,2,3)+B(:,:,2,1));
C(:,:,3,1) = ...
F(:,:,3,1).*(A(:,:,3,1).*E(:,:,3,1)+B(:,:,3,1)) +...
F(:,:,3,2).*(A(:,:,3,1).*E(:,:,3,2)+B(:,:,3,1)) +...
F(:,:,3,3).*(A(:,:,3,1).*E(:,:,3,3)+B(:,:,3,1));
end
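For reference, the elementwise direct function can be written in one line with implicit expansion. This is a sketch, assuming R2016b or later for implicit expansion and R2018b or later for max(..., 'all'); directVectorized is a hypothetical name introduced here. It makes visible that direct uses only .* and sums over the fourth dimension, so it returns an Nx-by-Ny-by-3 array and performs no matrix product.

```matlab
% directVectorized is a hypothetical one-line equivalent of direct().
% A and B (N-by-N-by-3-by-1) are implicitly expanded along dimension 4 to
% match E and F (N-by-N-by-3-by-3); sum(..., 4) adds the three j-terms.
directVectorized = @(A,B,E,F) sum(F.*(A.*E + B), 4);

% Small check against the explicit first output page of direct().
[E,F] = deal(rand(4,4,3,3));
[A,B] = deal(rand(4,4,3,1));
C = directVectorized(A,B,E,F);
C1 = F(:,:,1,1).*(A(:,:,1,1).*E(:,:,1,1)+B(:,:,1,1)) + ...
     F(:,:,1,2).*(A(:,:,1,1).*E(:,:,1,2)+B(:,:,1,1)) + ...
     F(:,:,1,3).*(A(:,:,1,1).*E(:,:,1,3)+B(:,:,1,1));
disp(size(C))                               % 4 4 3: no 4th dimension
disp(max(abs(C(:,:,1) - C1), [], 'all'))    % rounding-level difference only
```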
Since some of the replies to those questions suggested testing in single precision, I'll start with single precision on the GPU.
Nx=1000;
Ny=1000;
[E,F] = deal(gpuArray(single(rand(Nx,Ny,3,3))));
[A,B] = deal(gpuArray(single(rand(Nx,Ny,3,1))));
timeit(@()direct(A,B,E,F))
ans = 2.9201e-04
timeit(@()pagemtimes_version(A,B,E,F))
ans = 0.0045
The difference is about 15 times (0.0045 / 2.92e-4), and it grows with array size: with Nx = Ny = 5000 the difference is about 1000 times (0.1 s vs. 10^-4 s).
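One plausible contributor to a gap that grows with array size is arithmetic cost: pagemtimes_version performs dense matrix products on every Nx-by-Ny page, while direct performs only elementwise operations. The rough count below is a sketch that assumes Nx == Ny == N (so the page products are square) and ignores memory traffic, kernel-launch overhead and timing method; it only compares operation counts, not measured times.

```matlab
% Rough flop counts, assuming Nx == Ny == N (square pages) and ignoring
% memory traffic, kernel-launch overhead and implicit-expansion copies.
N = [1000 5000];

% pagemtimes_version: 9 page products E*A and 9 page products F*(...),
% each an N-by-N times N-by-N product (~2*N^3 flops), plus 9*N^2 adds.
flopsPagemtimes = 18 * 2 * N.^3 + 9 * N.^2;

% direct: per output page, 3 terms of (.*, +, .*) = 9*N^2 operations,
% plus 2*N^2 additions to combine them; 3 output pages in total.
flopsDirect = 3 * (9 + 2) * N.^2;

ratio = flopsPagemtimes ./ flopsDirect;
fprintf('N = %d: pagemtimes_version needs about %.0f times more flops\n', [N; ratio]);
```

Under these assumptions the ratio grows roughly linearly with N, because one side scales as N^3 and the other as N^2.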
On the CPU, in single precision:
[E,F] = deal(single(rand(Nx,Ny,3,3)));
[A,B] = deal(single(rand(Nx,Ny,3,1)));
timeit(@()direct(A,B,E,F))
ans = 0.0421
timeit(@()pagemtimes_version(A,B,E,F))
ans = 0.0514
In double precision, the GPU is even slower than the CPU:
[E,F] = deal(gpuArray(rand(Nx,Ny,3,3)));
[A,B] = deal(gpuArray(rand(Nx,Ny,3,1)));
timeit(@()direct(A,B,E,F))
ans = 2.6526e-04
timeit(@()pagemtimes_version(A,B,E,F))
ans = 0.1517
And on the CPU, in double precision:
[E,F] = deal(rand(Nx,Ny,3,3));
[A,B] = deal(rand(Nx,Ny,3,1));
timeit(@()direct(A,B,E,F))
ans = 0.0874
timeit(@()pagemtimes_version(A,B,E,F))
ans = 0.1163
The pagemtimes function is really handy, but it performs poorly for double-precision data and on the GPU. Is there any way to fix this?

Answers (2)

Joss Knight on 31 Oct 2024
I'm afraid your implementation is incorrect: you are using elementwise times (.*) rather than mtimes (*). You are also using timeit instead of gputimeit, which unfairly penalizes the pagemtimes code because it runs synchronously.
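A sketch of a fairer timing setup, assuming Parallel Computing Toolbox for gpuArray, gputimeit and gpuDevice. The names pagemtimesVersion and directVectorized are hypothetical anonymous-function stand-ins for the question's two functions, and N = 300 is an arbitrary small size. GPU work is queued asynchronously, so plain timeit on the elementwise code likely measures little more than kernel launch; gputimeit, or tic/toc followed by wait(gpuDevice), waits for the queued work to finish. The script falls back to timeit on the CPU when canUseGPU returns false.

```matlab
% Hypothetical stand-ins for the question's functions.
pagemtimesVersion = @(A,B,E,F) pagemtimes(F, B + pagemtimes(E,A));
directVectorized  = @(A,B,E,F) sum(F.*(A.*E + B), 4);  % same math as direct()

N = 300;  % arbitrary size, small enough to run quickly on a CPU
[E,F] = deal(single(rand(N,N,3,3)));
[A,B] = deal(single(rand(N,N,3,1)));

if canUseGPU()
    A = gpuArray(A); B = gpuArray(B); E = gpuArray(E); F = gpuArray(F);
    % gputimeit waits for all queued GPU work before stopping the clock.
    tDirect = gputimeit(@() directVectorized(A,B,E,F));
    tPage   = gputimeit(@() pagemtimesVersion(A,B,E,F));
    % Manual alternative: block on the device before reading toc.
    dev = gpuDevice;
    tic; C = pagemtimesVersion(A,B,E,F); wait(dev); tManual = toc;
else
    % No usable GPU: fall back to timeit on the CPU arrays.
    tDirect = timeit(@() directVectorized(A,B,E,F));
    tPage   = timeit(@() pagemtimesVersion(A,B,E,F));
    tic; C = pagemtimesVersion(A,B,E,F); tManual = toc;
end
fprintf('direct: %.3g s, pagemtimes: %.3g s, manual: %.3g s\n', ...
        tDirect, tPage, tManual);
```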
Hongbo Sun on 31 Oct 2024
Edited: Hongbo Sun on 31 Oct 2024
Thank you. I corrected the code in the example, and after using gputimeit correctly, everything worked. I'll look elsewhere to speed up my program.



the cyclist on 31 Oct 2024
It seems to me that the two functions are not calculating the same thing, based on the size of their respective outputs:
rng default
Nx=1000;
Ny=1000;
[E,F] = deal(single(rand(Nx,Ny,3,3)));
[A,B] = deal(single(rand(Nx,Ny,3,1)));
C1 = pagemtimes_version(A,B,E,F);
C2 = direct(A,B,E,F);
size(C1)
ans = 1×4
1000 1000 3 3
size(C2)
ans = 1×3
1000 1000 3
function C=pagemtimes_version(A,B,E,F)
C = pagemtimes(F,(B+pagemtimes(E,A)));
end
function C=direct(A,B,E,F)
C(:,:,1,1) = ...
F(:,:,1,1).*(A(:,:,1,1).*E(:,:,1,1)+B(:,:,1,1)) +...
F(:,:,1,2).*(A(:,:,1,1).*E(:,:,1,2)+B(:,:,1,1)) +...
F(:,:,1,3).*(A(:,:,1,1).*E(:,:,1,3)+B(:,:,1,1));
C(:,:,2,1) = ...
F(:,:,2,1).*(A(:,:,2,1).*E(:,:,2,1)+B(:,:,2,1)) +...
F(:,:,2,2).*(A(:,:,2,1).*E(:,:,2,2)+B(:,:,2,1)) +...
F(:,:,2,3).*(A(:,:,2,1).*E(:,:,2,3)+B(:,:,2,1));
C(:,:,3,1) = ...
F(:,:,3,1).*(A(:,:,3,1).*E(:,:,3,1)+B(:,:,3,1)) +...
F(:,:,3,2).*(A(:,:,3,1).*E(:,:,3,2)+B(:,:,3,1)) +...
F(:,:,3,3).*(A(:,:,3,1).*E(:,:,3,3)+B(:,:,3,1));
end
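The size mismatch follows from how pagemtimes treats its inputs: the first two dimensions of each page are multiplied as matrices, and the trailing page dimensions are implicitly expanded, so the 3-by-1 page grid of A and B is matched against the 3-by-3 grid of E and F. The small loop below with * (mtimes) is a sketch for checking that equivalence, not a faster alternative; pagemtimesVersion is a hypothetical anonymous-function copy of pagemtimes_version, and N = 5 is chosen only so the check runs instantly.

```matlab
% Hypothetical anonymous-function copy of pagemtimes_version.
pagemtimesVersion = @(A,B,E,F) pagemtimes(F, B + pagemtimes(E,A));

N = 5;
[E,F] = deal(rand(N,N,3,3));
[A,B] = deal(rand(N,N,3,1));

% Loop reference: A and B (3-by-1 page grid) are reused for every j.
Cloop = zeros(N,N,3,3);
for i = 1:3
    for j = 1:3
        % mtimes (*), not elementwise times (.*)
        Cloop(:,:,i,j) = F(:,:,i,j) * (B(:,:,i,1) + E(:,:,i,j) * A(:,:,i,1));
    end
end

Cpage = pagemtimesVersion(A,B,E,F);
disp(size(Cpage))                          % 5 5 3 3
disp(max(abs(Cpage - Cloop), [], 'all'))   % rounding-level difference only
```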
Hongbo Sun on 31 Oct 2024
Edited: Hongbo Sun on 31 Oct 2024
You're right, I reused the example from the earlier question without thinking it through. But the main problem was my incorrect use of the timeit function.


Release: R2024a
