Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Prevent line search from evaluating cost outside of the interpolation range #504

Merged
merged 7 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)


## 0.9.1
Expand Down
17 changes: 12 additions & 5 deletions ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
Expand Down Expand Up @@ -56,7 +56,7 @@ def line_search_armijo(
:math:`c_1` constant in the Armijo rule (>0)
alpha0 : float, optional
initial step (>0)
alpha_min : float, optional
alpha_min : float, default=0.
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
Expand Down Expand Up @@ -89,6 +89,14 @@ def line_search_armijo(
fc = [0]

def phi(alpha1):
# It is necessary to clip the step size to its bounds here, because the
# callback may be evaluated at negative values of alpha by the
# `scalar_search_armijo` function here:
#
# https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
#
# For more details, see https://github.com/PythonOT/POT/issues/502
alpha1 = np.clip(alpha1, alpha_min, alpha_max)
# The callable function operates on nx backend
fc[0] += 1
alpha10 = nx.from_numpy(alpha1)
Expand All @@ -109,13 +117,12 @@ def phi(alpha1):

derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)

if alpha is None:
return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
alpha = np.clip(alpha, alpha_min, alpha_max)
return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)


Expand Down
11 changes: 7 additions & 4 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
import pytest
import warnings

import ot
from ot.datasets import make_data_classif
Expand Down Expand Up @@ -158,15 +159,17 @@ def test_sinkhorn_l1l2_transport_class(nx):
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Xs, ys = make_data_classif('3gauss', ns, random_state=42)
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)

Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)

otda = ot.da.SinkhornL1l2Transport()
otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)

# test that it is computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
with warnings.catch_warnings():
warnings.simplefilter("error")
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert hasattr(otda, "coupling_")
assert hasattr(otda, "log_")
Expand Down
Loading