Skip to content

Commit

Permalink
better dist and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed Nov 27, 2024
1 parent 05341bb commit aa95c82
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 15 deletions.
43 changes: 36 additions & 7 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
from inspect import signature
from .backend import get_backend, Backend, NumpyBackend, JaxBackend

# Reference timestamp shared by tic(), toc() and toq(). It is initialised at
# import time so that toc()/toq() work even if tic() was never called.
# time.perf_counter() is used rather than time.time() because only differences
# between two readings are ever reported.
__time_tic_toc = time.perf_counter()


def tic():
    r"""Python implementation of Matlab tic() function

    Records the current ``time.perf_counter()`` reading in the module-level
    ``__time_tic_toc`` variable, which serves as the start of the timed
    interval.
    """
    global __time_tic_toc
    # Single assignment: a preceding time.time() store would be overwritten
    # immediately and never read.
    __time_tic_toc = time.perf_counter()


def toc(message="Elapsed time : {} s"):
    r"""Python implementation of Matlab toc() function

    Parameters
    ----------
    message : str, optional
        Format string with one ``{}`` placeholder that receives the elapsed
        time. The formatted message is printed.

    Returns
    -------
    elapsed : float
        Difference between the current ``time.perf_counter()`` reading and
        the reference stored in ``__time_tic_toc``.
    """
    # Compute the difference once so the printed and returned values are
    # guaranteed to be the same number.
    elapsed = time.perf_counter() - __time_tic_toc
    print(message.format(elapsed))
    return elapsed


def toq():
    r"""Python implementation of Julia toq() function

    Same measurement as :func:`toc` but silent: nothing is printed.

    Returns
    -------
    elapsed : float
        Difference between the current ``time.perf_counter()`` reading and
        the reference stored in ``__time_tic_toc``.
    """
    return time.perf_counter() - __time_tic_toc


Expand Down Expand Up @@ -291,11 +291,12 @@ def euclidean_distances(X, Y, squared=False):
return c


def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None, nx=None):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
from all compatible backends for the following metrics:
'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.
Parameters
----------
Expand All @@ -315,7 +316,8 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
w : array-like, rank 1
Weights for the weighted metrics.
nx : Backend, optional
Backend to perform computations on. If omitted, the backend is inferred from `x1` and `x2`.
Returns
-------
Expand All @@ -324,12 +326,39 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
distance matrix computed with given metric
"""
if nx is None:
nx = get_backend(x1, x2)
if x2 is None:
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
elif metric == "euclidean":
return euclidean_distances(x1, x2, squared=False)
elif metric == "cityblock":
return nx.sum(nx.abs(x1[:, None, :] - x2[None, :, :]), axis=2)
elif metric == "minkowski":
if w is None:
return nx.power(
nx.sum(nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), axis=2),
1 / p,
)
return nx.power(
nx.sum(
w[None, None, :] * nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p),
axis=2,
),
1 / p,
)
elif metric == "cosine":
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
elif metric == "correlation":
x1 = x1 - nx.mean(x1, axis=1)[:, None]
x2 = x2 - nx.mean(x2, axis=1)[:, None]
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
else:
if not get_backend(x1, x2).__name__ == "numpy":
raise NotImplementedError()
Expand Down
53 changes: 45 additions & 8 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@
import numpy as np
import sys
import pytest
import scipy

# Metrics used to parametrize the backend comparison test (test_dist_backends).
lst_metrics = [
    "euclidean",
    "sqeuclidean",
    "cityblock",
    "cosine",
    "minkowski",
    "correlation",
]

# Every metric compared against scipy's cdist (test_dist_vs_cdist): the list
# above followed by the additional metric names.
lst_all_metrics = [
    *lst_metrics,
    "braycurtis",
    "canberra",
    "chebyshev",
    "dice",
    "hamming",
    "jaccard",
    "matching",
    "rogerstanimoto",
    "russellrao",
    "sokalmichener",
    "sokalsneath",
    "yule",
]


def get_LazyTensor(nx):
Expand Down Expand Up @@ -185,7 +210,7 @@ def test_dist():

assert D4[0, 1] == D4[1, 0]

# dist shoul return squared euclidean
# dist should return squared euclidean
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)

Expand Down Expand Up @@ -230,20 +255,32 @@ def test_dist():
ot.dist(x, x, metric="wminkowski")


@pytest.mark.parametrize("metric", lst_metrics)
def test_dist_backends(nx, metric):
    """Check that ot.dist gives the same result on a backend array as on numpy.

    ``nx`` is the backend under test and ``metric`` is supplied by the
    parametrization, one test case per metric name.
    """
    n = 100
    rng = np.random.RandomState(0)
    x = rng.randn(n, 2)
    x1 = nx.from_numpy(x)

    # Reference on the numpy array, then the same call on the backend array.
    D = ot.dist(x, x, metric=metric)
    D1 = ot.dist(x1, x1, metric=metric)

    # loose atol because jax forces float32
    np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)

@pytest.mark.parametrize("metric", lst_all_metrics)
def test_dist_vs_cdist(metric):
    """Compare ot.dist with scipy's cdist on two random point sets.

    The two sets have different sizes (n and n + 1 points in 2D), so the
    distance matrix is rectangular.
    """
    n_samples = 10
    random_state = np.random.RandomState(0)

    # Draw order is kept (first set, then second) so the data is unchanged.
    first = random_state.randn(n_samples, 2)
    second = random_state.randn(n_samples + 1, 2)

    expected = scipy.spatial.distance.cdist(first, second, metric=metric)
    computed = ot.dist(first, second, metric=metric)

    np.testing.assert_allclose(computed, expected, atol=1e-15)


def test_dist0():
Expand Down

0 comments on commit aa95c82

Please sign in to comment.