
Computing the Gaussian Wasserstein distance

Matteo Tesori on 23 Mar 2021
Problem
I have to compute the Wasserstein distance between two bivariate Gaussian distributions with means $\mu_1$, $\mu_2$ and covariances $\Sigma_1$, $\Sigma_2$.
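According to equation 9 of this paper, in the Gaussian case the squared 2-Wasserstein distance admits the following analytic expression:

$$W_2^2 = \|\mu_1 - \mu_2\|_2^2 + \operatorname{tr}\left(\Sigma_1 + \Sigma_2 - 2\left(\Sigma_1^{1/2}\,\Sigma_2\,\Sigma_1^{1/2}\right)^{1/2}\right) \qquad (9)$$

My problem is implementing this equation.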
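For reference, a minimal MATLAB sketch of equation (9), assuming that $\Sigma^{1/2}$ denotes the principal (symmetric positive semidefinite) square root, which is what `sqrtm` computes. The name `w2sq` is hypothetical, and the covariance used in the check is the one from the identical-Gaussians experiment further down in this question.

```matlab
% Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2), eq. (9).
% sqrtm returns the principal (symmetric PSD) square root; real() discards
% tiny imaginary round-off parts that sqrtm can produce.
w2sq = @(m1, S1, m2, S2) sum((m1 - m2).^2) + ...
    real(trace(S1 + S2 - 2 * sqrtm(sqrtm(S1) * S2 * sqrtm(S1))));

% Identical Gaussians: the distance should vanish up to round-off.
m = [500; 500];
S = [16767 -3302; -3302 826];
d_same = w2sq(m, S, m, S);

% 1-D check: with variances 4 and 9 the formula reduces to
% (mu1 - mu2)^2 + (sigma1 - sigma2)^2 = (0 - 1)^2 + (2 - 3)^2 = 2.
d_1d = w2sq(0, 4, 1, 9);
```

The 1-D case is a quick sanity check: for scalars the trace term collapses to $(\sigma_1 - \sigma_2)^2$.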
Attempted solution
It is not clear to me what is meant by the square root $\Sigma^{1/2}$ of a matrix, and this is probably the core of my problem. I assume that in equation (9) the square root of a matrix is its Cholesky factor.
According to this interpretation, I wrote the following code to compute the Wasserstein distance given the parameters of the two Gaussian distributions
function [dd] = wass_dist(m1, Sigma1, m2, Sigma2)
sqrtSigma1 = chol(Sigma1);
sqrt_temp = chol(sqrtSigma1 * Sigma2 * sqrtSigma1);
ddm = (m1 - m2)' * (m1 - m2);
ddSigma = trace(Sigma1 + Sigma2 - 2 * sqrt_temp);
dd = ddm + ddSigma;
end
First, this code fails because the matrix sqrtSigma1 * Sigma2 * sqrtSigma1 is often not positive definite, so chol cannot factor it. I suspect this can be fixed in one of two ways: by transposing the first factor, i.e. using sqrtSigma1' * Sigma2 * sqrtSigma1, or by transposing the third factor, i.e. using sqrtSigma1 * Sigma2 * sqrtSigma1'. However, the aforementioned paper, and other papers as well, always write the formula without any transposition, which suggests that it does not contain a typo.
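A small sketch of why the product misbehaves when the Cholesky factor stands in for $\Sigma_1^{1/2}$. `A` and `B` are arbitrary example covariances chosen for illustration, not values from the paper. MATLAB's `chol` returns an upper-triangular `R` with `R' * R = A`, which is not a square root in the sense `R * R = A`, and `R * B * R` is not even symmetric; `sqrtm` returns a symmetric `S` with `S * S = A`.

```matlab
% Compare the Cholesky factor with the principal square root of an SPD matrix.
A = [4 1; 1 3];          % example covariance (illustrative)
B = [2 0.5; 0.5 1];      % second example covariance (illustrative)

R = chol(A);             % upper triangular, R' * R == A
S = sqrtm(A);            % symmetric PSD, S * S == A

err_chol_RtR = norm(R' * R - A);   % ~0: chol satisfies R' * R = A
err_chol_RR  = norm(R * R - A);    % clearly nonzero: R * R ~= A
err_sqrtm    = norm(S * S - A);    % ~0: S is a genuine square root
asym_S       = norm(S - S');       % ~0: S is symmetric

% With R in place of A^(1/2), the inner product of eq. (9) loses symmetry,
% whereas S * B * S stays symmetric positive definite.
asym_RBR = norm(R * B * R - (R * B * R)');
asym_SBS = norm(S * B * S - (S * B * S)');
```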
At this point I tried to compute the Wasserstein distance between two identical Gaussian distributions using the following modified function
function [dd] = wass_dist(m1, Sigma1, m2, Sigma2)
sqrtSigma1 = chol(Sigma1);
sqrt_temp = chol(sqrtSigma1 * Sigma2 * sqrtSigma1'); % third term transposed
ddm = (m1 - m2)' * (m1 - m2);
ddSigma = trace(Sigma1 + Sigma2 - 2 * sqrt_temp);
dd = ddm + ddSigma;
end
Unexpectedly, the output is not zero, because ddSigma is not zero. More precisely, the input arguments I tried are
m1 = [500 500]', Sigma1 = 1.0e+04 * [1.6767 -0.3302; -0.3302 0.0826]
m2 = m1, Sigma2 = Sigma1
and the resulting output is
dd = 6.6613
where ddm = 0 and ddSigma = 6.6613. This strongly suggests that something is wrong in the code, perhaps because the intended square root is not the Cholesky factor. I also tried the variant in which the first factor is transposed, and the result is even worse. With the same inputs, the result is
dd = ddSigma = 1.3002e+03
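A sketch that reproduces this experiment with the inputs above and isolates where the error comes from. It rests on a standard linear-algebra fact: with $\Sigma_1 = R^\top R$, the matrix $R\,\Sigma_2 R^\top$ has the same eigenvalues as $\Sigma_1 \Sigma_2$ and hence as $\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2}$, and the trace of the principal square root depends only on those eigenvalues. So the third-factor-transposed product is fine as long as its outer root is taken with `sqrtm`; `R' * Sigma2 * R` has different eigenvalues in general. The variable names are illustrative, and the matrices are symmetrized only to remove round-off asymmetry.

```matlab
% Identical Gaussians, inputs taken from the question.
Sigma = 1.0e+04 * [1.6767 -0.3302; -0.3302 0.0826];
R = chol(Sigma);                 % R' * R == Sigma

M = R * Sigma * R';              % third factor transposed
M = (M + M') / 2;                % remove round-off asymmetry
N = R' * Sigma * R;              % first factor transposed
N = (N + N') / 2;

% (a) outer root via chol, as in the modified function: not zero
dd_chol  = trace(2 * Sigma - 2 * chol(M));
% (b) same inner product, outer root via sqrtm: zero up to round-off
dd_sqrtm = trace(2 * Sigma - 2 * sqrtm(M));
% (c) first factor transposed, outer root via sqrtm: still not zero
dd_first = trace(2 * Sigma - 2 * sqrtm(N));
```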
Question
Is my code for computing equation (9) correct? If not, how can I fix it?
1 Comment
Ogul Can Yurdakul on 16 May 2023
Hey Matteo,
I believe using chol() is your problem. The Cholesky decomposition of $\Sigma$ gives you a matrix $C$ such that $C^\top C = \Sigma$, and the problem (I assume) is that this is a "transposed" square root. Using a symmetric square root, meaning a matrix square root $S = S^\top$ such that $S S = \Sigma$, might just solve your problem. I use the code below with no problems.
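function [dist] = GW_dist(mu_1, cov_1, mu_2, cov_2)
% Returns the 2-Wasserstein distance W_2 (not squared) between
% N(mu_1, cov_1) and N(mu_2, cov_2); mu_1 and mu_2 are column vectors.
dist = (mu_1 - mu_2).' * (mu_1 - mu_2);   % squared distance between the means
% cov^0.5 is the symmetric (principal) matrix square root, not the chol factor
dist = dist + trace(cov_1 + cov_2 - 2*(cov_1^0.5 * cov_2 * cov_1^0.5)^0.5);
dist = dist^0.5;                          % take the root: W_2 rather than W_2^2
end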
Hope it helps!
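A possible alternative, sketched under the assumption that both covariances are symmetric positive semidefinite: since $\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2}$ has the same eigenvalues as $\Sigma_1\Sigma_2$, the trace term equals $\sum_i \sqrt{\lambda_i(\Sigma_1\Sigma_2)}$, so the nested matrix square roots can be replaced by one call to `eig`. The function names `w2_eig` and `w2_sqrtm` and the example inputs are hypothetical; the `max(0, ...)` and `real` calls only guard against round-off.

```matlab
% W_2 via eigenvalues: trace((S1^(1/2)*S2*S1^(1/2))^(1/2)) = sum(sqrt(eig(S1*S2)))
w2_eig = @(m1, S1, m2, S2) sqrt(max(0, ...
    sum((m1 - m2).^2) + trace(S1 + S2) ...
    - 2 * sum(sqrt(max(0, real(eig(S1 * S2)))))));

% Reference: the same formula as in the answer above, written with sqrtm.
w2_sqrtm = @(m1, S1, m2, S2) sqrt(max(0, real( ...
    sum((m1 - m2).^2) + trace(S1 + S2 - 2 * sqrtm(sqrtm(S1) * S2 * sqrtm(S1))))));

m1 = [0; 0];  S1 = [4 1; 1 3];       % example inputs (illustrative)
m2 = [1; 2];  S2 = [2 0.5; 0.5 1];
d_eig   = w2_eig(m1, S1, m2, S2);
d_sqrtm = w2_sqrtm(m1, S1, m2, S2);
```

Both forms should agree to round-off; the eigenvalue form avoids computing two matrix square roots, though `eig` of the nonsymmetric product `S1 * S2` can return tiny imaginary parts, hence the `real` guard.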


Answers (0)
