How to implement Maximum Mean Discrepancy (MMD) with the pdist2 function?

The Maximum Mean Discrepancy (MMD) is the basis of a kernel two-sample test and one of the most widely used distance metrics in transfer learning; it assesses how similar two probability distributions are. The core idea of MMD is to map samples from the two distributions into a high-dimensional (possibly infinite-dimensional) feature space and measure the distance between the mean embeddings of the two distributions in that space, which quantifies the difference between their marginal distributions. The (biased) empirical estimate of the squared MMD is:

$$\mathrm{MMD}^2(X,Y)=\frac{1}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m}k(x_i,x_j)+\frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n}k(y_i,y_j)-\frac{2}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n}k(x_i,y_j)$$

where m and n denote the numbers of samples in X and Y, and k is the kernel function; common choices include the Gaussian, linear, and polynomial kernels.
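As a readable reference, the biased estimator of MMD^2 averages k over all X-X pairs and all Y-Y pairs and subtracts twice the average over X-Y pairs. The sketch below writes those three sums as explicit loops on plain arrays. It assumes observations are stored as columns (the C-by-B layout used in the code below) and a Gaussian kernel k(x,y) = exp(-||x - y||^2 / (2*sigma^2)); `sigma` is a hypothetical bandwidth chosen only for illustration.

```matlab
% Biased empirical MMD^2 with explicit loops, one loop nest per term.
% Assumptions: columns are observations (C-by-B), Gaussian kernel,
% sigma is a hypothetical bandwidth.
rng(0);
X = rand(3,5);                 % m = 5 samples from the first distribution
Y = rand(3,10);                % n = 10 samples from the second distribution
sigma = 1;                     % hypothetical kernel bandwidth
k = @(a,b) exp(-sum((a - b).^2) / (2*sigma^2));   % Gaussian kernel on two column vectors

m = size(X,2);
n = size(Y,2);
sXX = 0; sYY = 0; sXY = 0;
for i = 1:m
    for j = 1:m
        sXX = sXX + k(X(:,i), X(:,j));   % X-X pairs
    end
    for j = 1:n
        sXY = sXY + k(X(:,i), Y(:,j));   % X-Y pairs
    end
end
for i = 1:n
    for j = 1:n
        sYY = sYY + k(Y(:,i), Y(:,j));   % Y-Y pairs
    end
end
mmd2 = sXX/m^2 + sYY/n^2 - 2*sXY/(m*n);   % biased MMD^2 estimate
```

The double loop needs O((m+n)^2) kernel evaluations, so it is only meant as a check for a vectorized version, not for use inside a training loop.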
I want to use the MATLAB Deep Learning Toolbox to calculate the MMD between two minibatches of samples. Example code is shown below:
sourceFeats = dlarray(rand(3,5),"CB");
targetFeats = dlarray(rand(3,10),"CB");
% Compute the kernel matrices between samples (columns are observations)
sourceFeats = stripdims(sourceFeats); % remove the dimension labels for plain matrix algebra
targetFeats = stripdims(targetFeats);
kerXX = sourceFeats.'*sourceFeats; % linear kernel, 5-by-5
kerYY = targetFeats.'*targetFeats; % linear kernel, 10-by-10
kerXY = sourceFeats.'*targetFeats; % linear kernel, 5-by-10
MMDsq2 = mean(kerXX,"all") + mean(kerYY,"all") - 2*mean(kerXY,"all"); % biased estimate of MMD^2
% ------------------------------------------ %
% I want to use pdist2 to compute the pairwise kernel values more efficiently
pdist2(sourceFeats.',targetFeats.',@myKerFcn); % error: pdist2 does not accept dlarray input
% Error using zeros
% Class name must be a class that supports ZEROS, such as "double" or "single".
%
% Error in statslib.internal.pdist2 (line 754)
%     D = zeros(nx,ny,class(D));
% Error in pdist2 (line 158)
%     [varargout{1:nargout}] = statslib.internal.pdist2(X,Y,varargin{:});
pdist2(extractdata(sourceFeats).',extractdata(targetFeats).',@myKerFcn); % runs, but the result is no longer a traced dlarray
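function dis = myKerFcn(obsX,obsY)
% Custom distance function for pdist2: obsX is one observation (1-by-C),
% obsY holds several observations (one per row); returns one value per row of obsY.
arguments (Input)
    obsX (1,:)
    obsY (:,:)
end
arguments (Output)
    dis (:,1)
end
% Kernel value exp(-||x - y||/2); square the norm to obtain a Gaussian kernel
dis = exp(-vecnorm(obsX-obsY,2,2)/2);
end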
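For the linear kernel used in the code above there is a cheaper equivalent: averaging x_i.'*x_j over all pairs equals muX.'*muX (the inner product of the sample means), and likewise for the other two terms, so the biased MMD^2 reduces to the squared Euclidean distance between the two sample means. The sketch below checks this on plain arrays; the operations involved (mean, sum, subtraction, elementwise power) are ones dlarray also supports, but that use is an assumption here and is not exercised in the snippet.

```matlab
% Linear kernel k(x,y) = x.'*y: mean(X.'*X,"all") equals muX.'*muX, and
% likewise for the other two terms, so the biased MMD^2 collapses to the
% squared Euclidean distance between the two sample means.
rng(0);
X = rand(3,5);     % source features, C-by-m (observations in columns)
Y = rand(3,10);    % target features, C-by-n

% Kernel-matrix form, as in the question (builds m-by-m, n-by-n, m-by-n matrices)
mmd2Kernel = mean(X.'*X,"all") + mean(Y.'*Y,"all") - 2*mean(X.'*Y,"all");

% Mean-embedding form: only two C-by-1 mean vectors are needed
muX = mean(X,2);
muY = mean(Y,2);
mmd2Mean = sum((muX - muY).^2);
```

This shortcut holds only for the linear kernel; for Gaussian or polynomial kernels the pairwise matrices are still required.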
I want to use the pdist2 function to calculate the pairwise distances between two minibatches of samples, but pdist2 does not accept dlarray input arguments. Can anyone advise how to calculate the pairwise distances between two minibatches without losing the dlarray object? The dlarray object is essential for gradient tracing in deep learning.
3 Comments
Claire on 28 Nov 2025 at 8:54
The pdist2 function is designed for ordinary numeric arrays: pdist2 (from the Statistics and Machine Learning Toolbox) works on single and double inputs. As the error trace shows, its implementation preallocates the output with zeros(nx,ny,class(D)). For dlarray inputs the class name passed to zeros is "dlarray", which zeros does not accept, so passing a dlarray to pdist2 throws an error.
extractdata breaks the gradient trace: As you discovered, extractdata converts a dlarray back into its underlying numeric array (for example double), which lets pdist2 run. However, the result is no longer traced, so dlgradient cannot propagate gradients back through the pdist2 call, and the dimension labels (e.g., "CB") are lost as well. This defeats the purpose of using dlarray in a deep learning training loop.
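By contrast, kernel values built from matrix arithmetic on the dlarray itself stay on the trace. The sketch below (it assumes the Deep Learning Toolbox is installed) differentiates the linear-kernel MMD^2 from the question with dlfeval and dlgradient and compares the result with the closed-form gradient 2*(muX - muY)/m for every source column, which follows from MMD^2 = ||muX - muY||^2 for the linear kernel. The handle names `linMMD2` and `gradFcn` are illustrative.

```matlab
% Gradients flow through kernel-matrix arithmetic on unformatted dlarray.
% Requires the Deep Learning Toolbox.
rng(0);
X = dlarray(rand(3,5));     % unformatted dlarray, C-by-m (as after stripdims)
Y = dlarray(rand(3,10));    % unformatted dlarray, C-by-n
linMMD2 = @(X,Y) mean(X.'*X,"all") + mean(Y.'*Y,"all") - 2*mean(X.'*Y,"all");

% dlgradient must be called inside a function evaluated by dlfeval
gradFcn = @(X,Y) dlgradient(linMMD2(X,Y), X);
gX = dlfeval(gradFcn, X, Y);   % dlarray, same size as X

% Closed-form gradient, computed on plain arrays for comparison
Xd = extractdata(X);
Yd = extractdata(Y);
gExpected = repmat(2*(mean(Xd,2) - mean(Yd,2)) / size(Xd,2), 1, size(Xd,2));
maxErr = max(abs(extractdata(gX) - gExpected), [], "all");
```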
You cannot directly use pdist2 to compute distances involving dlarray objects.
Chuguang Pan on 28 Nov 2025 at 9:00
Edited: Chuguang Pan on 28 Nov 2025 at 9:00
@Claire Thanks for your reply. Are there any other convenient ways to calculate the pairwise distances between two minibatches without losing the gradient information?
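One way to avoid pdist2 is to expand the squared distance as ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.'*y, which needs only a matrix product, column sums, and implicit expansion. The sketch below builds a Gaussian-kernel MMD^2 from that expansion. It runs on plain arrays here; applying the same handles to unformatted dlarray inputs (for example after stripdims) is an assumption based on dlarray supporting these operations, and `sigma = 1` is a hypothetical bandwidth.

```matlab
% Pairwise squared Euclidean distances between the columns of X (C-by-m)
% and Y (C-by-n) without pdist2, plus a Gaussian-kernel MMD^2.
sqDist = @(X,Y) max(sum(X.^2,1).' + sum(Y.^2,1) - 2*(X.'*Y), 0);  % m-by-n, clamp round-off below 0
gaussK = @(X,Y,sigma) exp(-sqDist(X,Y) ./ (2*sigma^2));             % Gaussian kernel matrix
mmd2   = @(X,Y,sigma) mean(gaussK(X,X,sigma),"all") + mean(gaussK(Y,Y,sigma),"all") ...
                      - 2*mean(gaussK(X,Y,sigma),"all");            % biased MMD^2 estimate

rng(0);
X = rand(3,5);     % source features, C-by-m
Y = rand(3,10);    % target features, C-by-n
D = sqDist(X,Y);   % 5-by-10 squared distances
loss = mmd2(X,Y,1);

% Hypothetical use inside a function evaluated by dlfeval:
%   loss = mmd2(stripdims(sourceFeats), stripdims(targetFeats), 1);
%   grad = dlgradient(loss, sourceFeats);
```

Working with squared distances (instead of taking sqrt to get distances) also avoids the non-differentiable square root at zero distance, which occurs on the diagonals of the X-X and Y-Y matrices.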


Answers (0)

