Skip to content

Commit

Permalink
Merge branch 'master' into fix-line-search-zero-cost
Browse files Browse the repository at this point in the history
  • Loading branch information
kachayev authored Sep 1, 2023
2 parents 984f982 + 5331480 commit 47d681a
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 26 deletions.
4 changes: 2 additions & 2 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,8 +1394,8 @@ class LinearGWTransport(LinearTransport):
References
----------
.. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
distances between Gaussian distributions. Journal of Applied Probability,
59(4), 1178-1198.
distances between Gaussian distributions. Journal of Applied Probability,
59(4), 1178-1198.
"""

Expand Down
44 changes: 27 additions & 17 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left at its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
Expand All @@ -92,8 +92,8 @@ def entropic_gromov_wasserstein(
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
Otherwise G0 will be used as initial transport of the solver. G0 is not
required to satisfy marginal constraints but we strongly recommand it
to correcly estimate the GW distance.
required to satisfy marginal constraints but we strongly recommend it
to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Expand Down Expand Up @@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
if solver not in ['PGD', 'PPA']:
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)

if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

C1, C2 = list_to_array(C1, C2)
arr = [C1, C2]
if p is not None:
Expand Down Expand Up @@ -280,7 +283,7 @@ def entropic_gromov_wasserstein2(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left at its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
Expand Down Expand Up @@ -373,8 +376,8 @@ def entropic_gromov_barycenters(
lambdas : list of float, optional
List of the `S` spaces' weights.
If left at its default value None, uniform weights are taken.
loss_fun : callable, optional
tensor-matrix multiplication function based on specific loss function
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
symmetric : bool, optional.
Expand Down Expand Up @@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

Cs = list_to_array(*Cs)
arr = [*Cs]
if ps is not None:
Expand Down Expand Up @@ -459,7 +465,6 @@ def entropic_gromov_barycenters(

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)

Expand Down Expand Up @@ -550,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left at its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If left at its default value None, a symmetry test will be conducted.
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
Otherwise G0 will be used as initial transport of the solver. G0 is not
required to satisfy marginal constraints but we strongly recommand it
to correcly estimate the GW distance.
required to satisfy marginal constraints but we strongly recommend it
to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Expand Down Expand Up @@ -611,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
if solver not in ['PGD', 'PPA']:
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)

if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

M, C1, C2 = list_to_array(M, C1, C2)
arr = [M, C1, C2]
if p is not None:
Expand Down Expand Up @@ -762,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
q : array-like, shape (nt,), optional
Distribution in the target space.
If let to its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
Expand All @@ -775,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
Otherwise G0 will be used as initial transport of the solver. G0 is not
required to satisfy marginal constraints but we strongly recommand it
to correcly estimate the GW distance.
required to satisfy marginal constraints but we strongly recommend it
to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Expand Down Expand Up @@ -857,8 +865,8 @@ def entropic_fused_gromov_barycenters(
lambdas : list of float, optional
List of the `S` spaces' weights.
If let to its default value None, uniform weights are taken.
loss_fun : callable, optional
tensor-matrix multiplication function based on specific loss function
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
symmetric : bool, optional.
Expand Down Expand Up @@ -907,6 +915,9 @@ def entropic_fused_gromov_barycenters(
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

Cs = list_to_array(*Cs)
Ys = list_to_array(*Ys)
arr = [*Cs, *Ys]
Expand Down Expand Up @@ -977,7 +988,6 @@ def entropic_fused_gromov_barycenters(

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)

Expand Down
12 changes: 10 additions & 2 deletions ot/gromov/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
nx : backend, optional
If left at its default value None, a backend test will be conducted.
Returns
-------
constC : array-like, shape (ns, nt)
Expand Down Expand Up @@ -118,6 +119,8 @@ def h1(a):

def h2(b):
return nx.log(b + 1e-15)
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

constC1 = nx.dot(
nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
Expand Down Expand Up @@ -402,11 +405,12 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
T : array-like, shape (ns, nt)
Coupling between source and target spaces
p : array-like, shape (ns,)
loss_fun : str, optional
Name of loss function to use. Only 'square_loss' is currently implemented; 'kl_loss' raises NotImplementedError (default='square_loss')
nx : backend, optional
If left at its default value None, a backend test will be conducted.
Returns
-------
constC : array-like, shape (ns, nt)
Expand Down Expand Up @@ -446,6 +450,10 @@ def h1(a):

def h2(b):
return 2 * b
elif loss_fun == 'kl_loss':
raise NotImplementedError()
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")

constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
nx.ones((1, C2.shape[0]), type_as=p))
Expand Down
79 changes: 74 additions & 5 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,32 @@ def line_search(cost, G, deltaG, Mi, cost_G):
np.testing.assert_allclose(res, Gb, atol=1e-06)


@pytest.mark.parametrize('loss_fun', [
    'square_loss',
    'kl_loss',
    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_gw_helper_validation(loss_fun):
    """Call ``ot.gromov.init_matrix`` with each parametrized ``loss_fun``.

    The 'square_loss' and 'kl_loss' cases must run without error. The
    'unknown_loss' case is marked ``xfail(raises=ValueError)``, i.e. the
    call is expected to raise ``ValueError``.
    """
    size = 20  # number of samples drawn for each space
    mean = np.array([0, 0])
    identity_cov = np.array([[1, 0], [0, 1]])

    # Two samples from the same 2D Gaussian, differing only by seed.
    source = ot.datasets.make_2D_samples_gauss(
        size, mean, identity_cov, random_state=0)
    target = ot.datasets.make_2D_samples_gauss(
        size, mean, identity_cov, random_state=1)

    # Uniform weights on both spaces.
    weights_src = ot.unif(size)
    weights_tgt = ot.unif(size)

    # Intra-space pairwise cost matrices.
    cost_src = ot.dist(source, source)
    cost_tgt = ot.dist(target, target)

    ot.gromov.init_matrix(
        cost_src, cost_tgt, weights_src, weights_tgt, loss_fun=loss_fun)


@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
@pytest.mark.parametrize('loss_fun', [
'square_loss',
'kl_loss',
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_entropic_gromov(nx, loss_fun):
n_samples = 10 # nb samples

mu_s = np.array([0, 0])
Expand All @@ -319,10 +342,10 @@ def test_entropic_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)

G, log = ot.gromov.entropic_gromov_wasserstein(
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
epsilon=1e-2, max_iter=10, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None,
epsilon=1e-2, max_iter=10, verbose=True, log=False
))

Expand All @@ -333,11 +356,40 @@ def test_entropic_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov


@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.mark.parametrize('loss_fun', [
'square_loss',
'kl_loss',
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_entropic_gromov2(nx, loss_fun):
n_samples = 10 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)

xt = xs[::-1].copy()

p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)

C1 /= C1.max()
C2 /= C2.max()

C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)

gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
C1, C2, p, None, loss_fun, symmetric=True, G0=None,
max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b,
max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)

Expand Down Expand Up @@ -2030,6 +2082,23 @@ def line_search(cost, G, deltaG, Mi, cost_G):
np.testing.assert_allclose(res, Gb, atol=1e-06)


@pytest.mark.parametrize('loss_fun', [
    'square_loss',
    pytest.param('kl_loss', marks=pytest.mark.xfail(raises=NotImplementedError)),
    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_gw_semirelaxed_helper_validation(loss_fun):
    """Call ``ot.gromov.init_matrix_semirelaxed`` with each ``loss_fun``.

    Only 'square_loss' must run without error. The 'kl_loss' case is marked
    ``xfail(raises=NotImplementedError)`` and the 'unknown_loss' case is
    marked ``xfail(raises=ValueError)``, i.e. those calls are expected to
    raise the respective exception.
    """
    size = 20  # number of samples drawn for each space
    mean = np.array([0, 0])
    identity_cov = np.array([[1, 0], [0, 1]])

    # Two samples from the same 2D Gaussian, differing only by seed.
    source = ot.datasets.make_2D_samples_gauss(
        size, mean, identity_cov, random_state=0)
    target = ot.datasets.make_2D_samples_gauss(
        size, mean, identity_cov, random_state=1)

    # The semi-relaxed helper takes the source weights only.
    weights_src = ot.unif(size)

    # Intra-space pairwise cost matrices.
    cost_src = ot.dist(source, source)
    cost_tgt = ot.dist(target, target)

    ot.gromov.init_matrix_semirelaxed(
        cost_src, cost_tgt, weights_src, loss_fun=loss_fun)


def test_semirelaxed_fgw(nx):
rng = np.random.RandomState(0)
list_n = [16, 8]
Expand Down

0 comments on commit 47d681a

Please sign in to comment.