Vectorizing bsxfun

Hi,
I have two matrices (which are really lists of row vectors) and would like a matrix of the pairwise squared distances between every vector in the first list and every vector in the second. The following code does what I want, but I'm curious whether there's any way to vectorize it.
Thank you
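% nTrain, nClass, numDim and the scaling constant D are defined elsewhere
rX=rand(nTrain,numDim);          % nTrain vectors, one per row
rXClass=rand(nClass,numDim);     % nClass vectors, one per row
dists=zeros(nTrain,nClass);
for ii=1:nTrain
    thisX=rX(ii,:);
    % squared distance from row ii of rX to every row of rXClass
    dists(ii,:)=sum(bsxfun(@minus,thisX,rXClass).^2,2)/D;
end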

Answers (3)

the cyclist on 16 Aug 2011

0 votes

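rX2 = permute(rX,[1 3 2]);            % nTrain x 1 x numDim
rXClass2 = permute(rXClass,[3 1 2]);  % 1 x nClass x numDim
% bsxfun expands to nTrain x nClass x numDim; summing over dim 3 gives nTrain x nClass
dists = sum(bsxfun(@minus,rX2,rXClass2).^2,3)/D;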
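On MATLAB R2016b or later, the broadcasting that bsxfun performs here is built into the arithmetic operators (implicit expansion), so the bsxfun call can be written as a plain subtraction. A minimal sketch, assuming small example sizes and D = 1 (chosen only for illustration), that also checks the result against the loop from the question:

```matlab
% Assumed example sizes, for illustration only
nTrain = 4; nClass = 3; numDim = 5; D = 1;
rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);

% Same reshaping as the answer above, but relying on implicit expansion
rX2      = permute(rX,      [1 3 2]);   % nTrain x 1 x numDim
rXClass2 = permute(rXClass, [3 1 2]);   % 1 x nClass x numDim
diffs    = rX2 - rXClass2;              % expands to nTrain x nClass x numDim
distsIE  = sum(diffs.^2, 3) / D;        % nTrain x nClass

% Reference: the loop from the question
distsLoop = zeros(nTrain, nClass);
for ii = 1:nTrain
    distsLoop(ii,:) = sum(bsxfun(@minus, rX(ii,:), rXClass).^2, 2) / D;
end
```

On releases before R2016b the subtraction of these mismatched sizes raises a dimension error, so bsxfun is still needed there.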
Sean de Wolski on 16 Aug 2011

0 votes

% bsxfun gives nTrain x numDim x nClass; sum over dim 2, then squeeze to nTrain x nClass
dists2 = squeeze(sum(bsxfun(@minus,rX,reshape(rXClass',[1 numDim nClass])).^2,2))/D;
With all three sizes set to 150, your plain for-loop runs the fastest for me:
nTrain = 150;
numDim = 150;
nClass = 150;
D = 1;
rX=rand(nTrain,numDim);
rXClass=rand(nClass,numDim);
t1 = 0;   % accumulated time: for-loop
t2 = 0;   % accumulated time: reshape/squeeze version
t3 = 0;   % accumulated time: permute version
for jj = 1:50
    tic
    dists=zeros(nTrain,nClass);
    for ii=1:nTrain
        thisX=rX(ii,:);
        dists(ii,:)=sum(bsxfun(@minus,thisX,rXClass).^2,2)/D;
    end
    t1 = t1+toc;
    tic
    dists2 = squeeze(sum(bsxfun(@minus,rX,reshape(rXClass',[1 numDim nClass])).^2,2))/D;
    t2 = t2+toc;
    tic
    rX2 = permute(rX,[1 3 2]);
    rXClass2 = permute(rXClass,[3 1 2]);
    dists3 = sum(bsxfun(@minus,rX2,rXClass2).^2,3)/D;
    t3 = t3+toc;
end
isequal(dists,dists2,dists3)   % do all three methods give identical results?
[t1 t2 t3]                     % total seconds over the 50 runs
ans =
1
ans =
3.1505 4.0336 4.0368

4 Comments

Sean de Wolski on 16 Aug 2011
The biggest time sink is actually the .^2. Removing that doubles the speed :(
Brendan on 16 Aug 2011
Sadly, I'm not sure there's anything I can do about that one.
Sean de Wolski on 16 Aug 2011
Nope, probably not. The only hope I can think of for that one is maybe with James' mtimesx on File Exchange.
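The comment above points at matrix multiplication (via mtimesx) as the way around the .^2 cost. A related but different technique, not proposed in this thread, avoids the nTrain x nClass x numDim intermediate entirely by expanding the square: for row vectors x and y, sum((x-y).^2) = sum(x.^2) + sum(y.^2) - 2*x*y', so the cross term for all pairs becomes one ordinary matrix product. A sketch, assuming the benchmark's sizes and D = 1. Round-off can make a few entries slightly negative (clamped here with max), and the result agrees with the bsxfun versions only within floating-point tolerance, not exactly:

```matlab
% Sizes taken from the benchmark above; D = 1 as there
nTrain = 150; nClass = 150; numDim = 150; D = 1;
rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);

sqX   = sum(rX.^2, 2);        % nTrain x 1, squared norm of each row of rX
sqC   = sum(rXClass.^2, 2);   % nClass x 1, squared norm of each row of rXClass
dotXC = rX * rXClass';        % nTrain x nClass, all pairwise dot products
distsMM = bsxfun(@plus, sqX, sqC') - 2*dotXC;
distsMM = max(distsMM, 0) / D; % clamp tiny negatives caused by round-off

% Reference: the permute/bsxfun version from the answers
distsRef = sum(bsxfun(@minus, permute(rX,[1 3 2]), permute(rXClass,[3 1 2])).^2, 3) / D;
```

No timings are claimed for this; how it compares with the loop will depend on the sizes and the MATLAB version.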
the cyclist on 17 Aug 2011
If you prefer the vectorized code, both versions can be sped up a fair amount by splitting the one-liner calculation of "dists" into separate lines for the bsxfun call, the squaring, and the sum.
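What that comment describes might look like the following for the permute version: the bsxfun call, the squaring and the sum as separate statements. The comment gives no timings and none are claimed here; the check only confirms that the split version produces the same matrix as the one-liner. Sizes and D = 1 are taken from the benchmark above:

```matlab
nTrain = 150; nClass = 150; numDim = 150; D = 1;
rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);

rX2      = permute(rX,      [1 3 2]);   % nTrain x 1 x numDim
rXClass2 = permute(rXClass, [3 1 2]);   % 1 x nClass x numDim

% Split version: one operation per statement
d      = bsxfun(@minus, rX2, rXClass2); % differences, nTrain x nClass x numDim
d      = d.^2;                          % squaring, reusing the variable d
dists3 = sum(d, 3) / D;                 % sum over dim 3, nTrain x nClass

% Original one-liner, for comparison
dists3a = sum(bsxfun(@minus, rX2, rXClass2).^2, 3) / D;
```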


Andrei Bobrov on 16 Aug 2011

0 votes

My small contribution:
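[a, b] = meshgrid(1:nTrain,1:nClass);   % every (train, class) index pair
% one row per pair, summed over dimensions, then reshaped back to nTrain x nClass
dists = reshape(sum((rX(a(:),:) - rXClass(b(:),:)).^2,2),[],nTrain)'/D;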
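To see why the reshape and transpose at the end are needed, it helps to look at the index vectors meshgrid produces. A small sketch with assumed sizes nTrain = 2 and nClass = 3 (plus numDim = 4 and D = 1 for the comparison, all chosen only for illustration): the class index varies fastest in a(:) and b(:), so the summed column reshapes naturally to nClass x nTrain, and the transpose gives nTrain x nClass.

```matlab
% Assumed example sizes, for illustration only
nTrain = 2; nClass = 3; numDim = 4; D = 1;
[a, b] = meshgrid(1:nTrain, 1:nClass);   % both are nClass x nTrain
trainIdx = a(:)';   % 1 1 1 2 2 2
classIdx = b(:)';   % 1 2 3 1 2 3

rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);
distsMG = reshape(sum((rX(a(:),:) - rXClass(b(:),:)).^2, 2), [], nTrain)' / D;

% Reference: the loop from the question
distsLoop = zeros(nTrain, nClass);
for ii = 1:nTrain
    distsLoop(ii,:) = sum(bsxfun(@minus, rX(ii,:), rXClass).^2, 2) / D;
end
```

Note that rX(a(:),:) and rXClass(b(:),:) each build an (nTrain*nClass) x numDim array, so memory use grows with the product of all three sizes.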

Asked: 16 Aug 2011

