diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 265fab5b1..131757610 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -9,13 +9,11 @@ import pytest import ot +from ot.backend import tf from ot.lp import wasserstein_1d -from ot.backend import get_backend_list, tf from scipy.stats import wasserstein_distance -backend_list = get_backend_list() - def test_emd_1d_emd2_1d_with_weights(): # test emd1d gives similar results as emd @@ -53,10 +51,7 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(w_v, G.sum(0)) -@pytest.mark.parametrize('nx', backend_list) def test_wasserstein_1d(nx): - from scipy.stats import wasserstein_distance - rng = np.random.RandomState(0) n = 100 @@ -105,8 +100,6 @@ def test_wasserstein_1d_type_devices(nx): @pytest.mark.skipif(not tf, reason="tf not installed") def test_wasserstein_1d_device_tf(): - if not tf: - return nx = ot.backend.TensorflowBackend() rng = np.random.RandomState(0) n = 10 diff --git a/test/test_ot.py b/test/test_ot.py index cbb63185a..5c6e6732b 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -103,8 +103,6 @@ def test_emd_emd2_types_devices(nx): @pytest.mark.skipif(not tf, reason="tf not installed") def test_emd_emd2_devices_tf(): - if not tf: - return nx = ot.backend.TensorflowBackend() n_samples = 100