Why are the pagemtimes, cross and pagetranspose functions slower on the GPU?

I'm running my MATLAB code on my GPU (GTX 1060 Max-Q) and I'm having trouble with the execution time of the pagemtimes and cross functions. One of the matrices I'm working with (BBI1) is 2 x 2 x 418824, and the others are of similar size, so I expected GPU computing to pay off for arrays this large. Instead, I'm seeing a significant performance loss in the following part of my code:
BBI1 = pagemtimes(pagemtimes(xiII_T,xiII),II) - pagemtimes(xiII,xiII_T);
BBJ1 = 2*pagemtimes(xiI,xiII_T) - pagemtimes(pagemtimes(xiI_T,xiII),II) - pagemtimes(xiII,xiI_T);
eInII = pagemtimes(pagemtimes(xiII_T,xiII),xiI) - pagemtimes(pagemtimes(xiI_T,xiII),xiII);
eIInI = pagemtimes(pagemtimes(xiI_T,xiI),xiII) - pagemtimes(pagemtimes(xiII_T,xiI),xiI);
outI = 2 * CC * ((1./AA) - (1./aa)) .* BBI1 + 2 * CC * 1./(aa.^3) .* pagemtimes(eInII,pagetranspose(eInII));
outJ = 2 * CC * ((1./AA) - (1./aa)) .* BBJ1 + 2 * CC * 1./(aa.^3) .* pagemtimes(eInII,pagetranspose(eIInI));
Here is a minimal reproducible example of the performance test:
matrix = rand(2,2,418824);
matrix_T = pagetranspose(matrix);
% measure on CPU
tic
times1 = pagemtimes(pagemtimes(matrix,matrix_T),matrix);
toc
% measure on GPU
matrix = gpuArray(matrix);
matrix_T = pagetranspose(matrix);
tic
times2 = pagemtimes(pagemtimes(matrix,matrix_T),matrix);
toc
which gives the output:
Elapsed time is 0.012834 seconds. % CPU
Elapsed time is 0.440265 seconds. % GPU
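A caveat on these numbers: GPU operations in MATLAB run asynchronously, and the first gpuArray calls in a session can include one-off setup costs, so a single tic/toc around one call is not a reliable GPU benchmark. The following is a sketch of a more careful comparison using timeit and gputimeit, which call the function repeatedly (gputimeit also waits for the GPU to finish). It assumes Parallel Computing Toolbox and a supported GPU are available.

```matlab
% Sketch: repeat-and-average timing of the reproducible example.
matrix = rand(2,2,418824);
matrix_T = pagetranspose(matrix);
f_cpu = @() pagemtimes(pagemtimes(matrix,matrix_T),matrix);
t_cpu = timeit(f_cpu);              % CPU timing, repeated internally

gMatrix = gpuArray(matrix);
gMatrix_T = pagetranspose(gMatrix);
f_gpu = @() pagemtimes(pagemtimes(gMatrix,gMatrix_T),gMatrix);
t_gpu = gputimeit(f_gpu);           % GPU timing, waits for the device to finish

fprintf('CPU: %.6f s, GPU: %.6f s\n', t_cpu, t_gpu);
```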
Are pagemtimes, pagetranspose and cross simply not optimized for my case, or am I missing or misusing something? In all the examples of these functions I've seen, they provided large performance gains. How can I accelerate the code above?
Joss Knight on 14 Jun 2021
Have you tried the computation in single precision? Your card's double precision performance isn't up to much.
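Following this suggestion, here is a sketch of the same benchmark in single precision (assumes Parallel Computing Toolbox and a supported GPU; consumer GeForce cards generally have far lower double-precision than single-precision throughput). The comparison with the double-precision CPU result shows the accuracy cost of switching.

```matlab
% Sketch: run the reproducible example in single precision on the GPU.
matrix = rand(2,2,418824);
refResult = pagemtimes(pagemtimes(matrix, pagetranspose(matrix)), matrix);  % double, CPU

gMatrixS = gpuArray(single(matrix));      % convert to single, then copy to the GPU
gMatrixS_T = pagetranspose(gMatrixS);
f_single = @() pagemtimes(pagemtimes(gMatrixS, gMatrixS_T), gMatrixS);

t_single = gputimeit(f_single);
resultS = f_single();

% Largest deviation from the double-precision CPU result.
maxErr = max(abs(double(gather(resultS)) - refResult), [], 'all');
fprintf('GPU single: %.6f s, class %s, max abs error %g\n', ...
    t_single, classUnderlying(resultS), maxErr);
```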

Max Heiken on 13 Jun 2021
Take everything I write with a grain of salt, since I have never used GPU features myself in the past and this is just from experimenting with your reproducible example.
It seems to me that indeed pagemtimes is not optimized for your use case. As a test, I tried
matrix = rand(1000,1000,10);
and lo and behold, there was a 10x increase in speed for the GPU version compared to the CPU version.
Elapsed time is 0.077263 seconds. % CPU
Elapsed time is 0.007821 seconds. % GPU
The elapsed time scales linearly with the number of pages in both the CPU and GPU cases. This suggests to me that the matrix multiplication itself is parallelized while the pages are processed sequentially, which obviously doesn't help in your case.
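To check the scaling behaviour on a specific machine, a sketch like the following times both versions over a range of page counts (assumes Parallel Computing Toolbox and a supported GPU; the page counts are arbitrary choices). Linear growth alone is also what a fully parallel kernel shows once the GPU is saturated, so the ratio between the CPU and GPU columns is probably the more telling number.

```matlab
% Sketch: how CPU and GPU pagemtimes timings grow with the number of pages.
pageCounts = [1e4 5e4 1e5 2e5 4e5];
tCpu = zeros(size(pageCounts));
tGpu = zeros(size(pageCounts));
for k = 1:numel(pageCounts)
    A = rand(2,2,pageCounts(k));            % CPU array
    gA = gpuArray(A);                       % same data on the GPU
    tCpu(k) = timeit(@() pagemtimes(A, A));
    tGpu(k) = gputimeit(@() pagemtimes(gA, gA));
end
disp(table(pageCounts(:), tCpu(:), tGpu(:), ...
    'VariableNames', {'Pages', 'CpuSeconds', 'GpuSeconds'}));
```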
I came across an answer that suggested simply expanding the equation. This parallelizes over the pages and makes the matrix multiplication part sequential. Adding
m11 = matrix(1, 1, :);
m12 = matrix(1, 2, :);
m21 = matrix(2, 1, :);
m22 = matrix(2, 2, :);
times3 = [m21.*(m11.*m21 + m12.*m22) + m11.*(m11.^2 + m12.^2), m22.*(m11.*m21 + m12.*m22) + m12.*(m11.^2 + m12.^2);
m11.*(m11.*m21 + m12.*m22) + m21.*(m21.^2 + m22.^2), m12.*(m11.*m21 + m12.*m22) + m22.*(m21.^2 + m22.^2)];
at the bottom (with the original matrix = rand(2,2,418824)) gives
Elapsed time is 0.003063 seconds. % CPU (pagemtimes)
Elapsed time is 0.042687 seconds. % GPU (pagemtimes)
Elapsed time is 0.000707 seconds. % GPU (expanded out equation)
on my GTX 1070 Ti.
To get the above expanded equation, I used the Symbolic Math Toolbox:
syms m11 m12 m21 m22 real
m = [m11 m12; m21 m22];
expanded = m*m.'*m   % symbolic form of matrix*matrix_T*matrix
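The same expansion can be packaged as a general page-wise 2 x 2 product, so that it can be chained like pagemtimes. This is a minimal sketch in which pagemtimes2x2 is a hypothetical helper name (not a MATLAB function). It uses only indexing, element-wise arithmetic and concatenation, so it should accept both ordinary arrays and gpuArrays.

```matlab
% Sketch: page-wise product of 2x2xN arrays using element-wise operations only.
A = rand(2,2,1000);
B = rand(2,2,1000);
C = pagemtimes2x2(A, B);
err = max(abs(C - pagemtimes(A, B)), [], 'all');
fprintf('max abs difference vs pagemtimes: %g\n', err);

function C = pagemtimes2x2(A, B)
% PAGEMTIMES2X2 Hypothetical helper: multiplies each 2x2 page of A by the
% matching page of B. Each aij/bij below is a 1x1xN slice, so every
% operation acts on all pages at once.
a11 = A(1,1,:); a12 = A(1,2,:); a21 = A(2,1,:); a22 = A(2,2,:);
b11 = B(1,1,:); b12 = B(1,2,:); b21 = B(2,1,:); b22 = B(2,2,:);
% Concatenating 1x1xN slices this way rebuilds a 2x2xN result.
C = [a11.*b11 + a12.*b21, a11.*b12 + a12.*b22;
     a21.*b11 + a22.*b21, a21.*b12 + a22.*b22];
end
```

For example, pagemtimes2x2(pagemtimes2x2(matrix, matrix_T), matrix) should reproduce times3 above, up to rounding.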
Mert Solen on 13 Jun 2021
Thank you for your answer. It's really strange, because MATLAB's own documentation page for pagemtimes says that it's optimized for GPU arrays. In fact, Parallel Computing Toolbox has its own pagemtimes method in the gpuArray class that uses pagefun(@mtimes...). Is there any other workaround for this issue? I will be using page-wise multiplication many times, and it's not very memory-efficient to create and hold separate vectors for each case.
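Since pagefun(@mtimes, ...) comes up here, one quick check is to time it side by side with pagemtimes on the same gpuArray inputs and confirm they give the same products. This is a sketch that assumes Parallel Computing Toolbox and a supported GPU; whether the two paths differ in speed depends on the hardware and MATLAB release, so it is something to measure rather than assume.

```matlab
% Sketch: time pagemtimes against pagefun(@mtimes, ...) on gpuArray inputs.
nPages = 418824;
gA = rand(2,2,nPages,'gpuArray');
gB = rand(2,2,nPages,'gpuArray');

t_pagemtimes = gputimeit(@() pagemtimes(gA, gB));
t_pagefun    = gputimeit(@() pagefun(@mtimes, gA, gB));

% Both calls should produce the same page-wise products.
diffs = pagemtimes(gA, gB) - pagefun(@mtimes, gA, gB);
maxDiff = gather(max(abs(diffs(:))));
fprintf('pagemtimes: %.6f s, pagefun: %.6f s, max diff: %g\n', ...
    t_pagemtimes, t_pagefun, maxDiff);
```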
By the way, the same issue occurs with the pagetranspose and cross functions. I don't know whether splitting the matrix would make pagetranspose more efficient, or whether cross can be implemented some other way. It may be essential to find the issue these functions have in common.
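The "write it out element-wise" idea can also be applied to cross. For 3 x 1 x N vectors, each component of the cross product is a difference of element-wise products. The sketch below uses crossElementwise, a hypothetical helper name, and checks it against cross(a, b, 1). It also shows that permute(M, [2 1 3]) gives the same result as pagetranspose(M) for a 3-D array. Whether either variant is actually faster on a given GPU still needs to be measured.

```matlab
% Sketch: element-wise alternatives to pagetranspose and cross.
N = 1000;
M = rand(2,2,N);
Mt = permute(M, [2 1 3]);               % swaps rows and columns of every page
assert(isequal(Mt, pagetranspose(M)));

a = rand(3,1,N);
b = rand(3,1,N);
c = crossElementwise(a, b);
err = max(abs(c - cross(a, b, 1)), [], 'all');
fprintf('max abs difference vs cross: %g\n', err);

function c = crossElementwise(a, b)
% CROSSELEMENTWISE Hypothetical helper: cross product of 3x1xN vectors,
% written with indexing and element-wise arithmetic only.
a1 = a(1,:,:); a2 = a(2,:,:); a3 = a(3,:,:);
b1 = b(1,:,:); b2 = b(2,:,:); b3 = b(3,:,:);
c = [a2.*b3 - a3.*b2;
     a3.*b1 - a1.*b3;
     a1.*b2 - a2.*b1];
end
```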

