Skip to content

Commit

Permalink
Remove redundant parametrization from wasserstein_1d tests (#517)
Browse files Browse the repository at this point in the history
* No need to parametrize for nx param

* Remove redundant check for TF backend

* Remove redundant TF check from emd2 test

---------

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
kachayev and rflamary authored Sep 6, 2023
1 parent 5331480 commit 064898d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 10 deletions.
9 changes: 1 addition & 8 deletions test/test_1d_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
import pytest

import ot
from ot.backend import tf
from ot.lp import wasserstein_1d

from ot.backend import get_backend_list, tf
from scipy.stats import wasserstein_distance

backend_list = get_backend_list()


def test_emd_1d_emd2_1d_with_weights():
# test emd1d gives similar results as emd
Expand Down Expand Up @@ -53,10 +51,7 @@ def test_emd_1d_emd2_1d_with_weights():
np.testing.assert_allclose(w_v, G.sum(0))


@pytest.mark.parametrize('nx', backend_list)
def test_wasserstein_1d(nx):
from scipy.stats import wasserstein_distance

rng = np.random.RandomState(0)

n = 100
Expand Down Expand Up @@ -105,8 +100,6 @@ def test_wasserstein_1d_type_devices(nx):

@pytest.mark.skipif(not tf, reason="tf not installed")
def test_wasserstein_1d_device_tf():
if not tf:
return
nx = ot.backend.TensorflowBackend()
rng = np.random.RandomState(0)
n = 10
Expand Down
2 changes: 0 additions & 2 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ def test_emd_emd2_types_devices(nx):

@pytest.mark.skipif(not tf, reason="tf not installed")
def test_emd_emd2_devices_tf():
if not tf:
return
nx = ot.backend.TensorflowBackend()

n_samples = 100
Expand Down

0 comments on commit 064898d

Please sign in to comment.