partial entropic fgw solvers
cedricvincentcuaz committed Nov 27, 2024
1 parent 8e60257 commit 94a5e37
Showing 3 changed files with 419 additions and 6 deletions.
4 changes: 4 additions & 0 deletions ot/gromov/__init__.py
@@ -107,6 +107,8 @@
solve_partial_gromov_linesearch,
entropic_partial_gromov_wasserstein,
entropic_partial_gromov_wasserstein2,
entropic_partial_fused_gromov_wasserstein,
entropic_partial_fused_gromov_wasserstein2,
)


@@ -180,4 +182,6 @@
"solve_partial_gromov_linesearch",
"entropic_partial_gromov_wasserstein",
"entropic_partial_gromov_wasserstein2",
"entropic_partial_fused_gromov_wasserstein",
"entropic_partial_fused_gromov_wasserstein2",
]
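
For orientation, here is a minimal import sketch (an illustrative assumption, not part of the diff) showing where the two new solvers become available at the subpackage level once this change lands:

# Assumed usage after this commit: the new partial entropic FGW solvers are
# re-exported by ot.gromov (they are defined in ot/gromov/_partial.py below).
from ot.gromov import (
    entropic_partial_fused_gromov_wasserstein,
    entropic_partial_fused_gromov_wasserstein2,
)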
377 changes: 377 additions & 0 deletions ot/gromov/_partial.py
@@ -1433,3 +1433,380 @@ def entropic_partial_gromov_wasserstein2(
return log_gw["partial_gw_dist"], log_gw
else:
return log_gw["partial_gw_dist"]


def entropic_partial_fused_gromov_wasserstein(
M,
C1,
C2,
p=None,
q=None,
reg=1.0,
m=None,
loss_fun="square_loss",
alpha=0.5,
G0=None,
numItermax=1000,
tol=1e-7,
symmetric=None,
log=False,
verbose=False,
):
r"""
Returns the entropic partial Fused Gromov-Wasserstein transport between
:math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and
:math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise
distance matrix :math:`\mathbf{M}` between node feature matrices.
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \cdot
\gamma_{i,j} \cdot \gamma_{k,l} + \mathrm{reg} \cdot \Omega(\gamma)
.. math::
s.t. \ \gamma &\geq 0
\gamma \mathbf{1} &\leq \mathbf{p}
\gamma^T \mathbf{1} &\leq \mathbf{q}
\mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}
where :
- :math:`\mathbf{M}`: metric cost matrix between features across domains
- :math:`\mathbf{C_1}` is the metric cost matrix in the source space
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L`: loss function to account for the misfit between the similarity matrices
- :math:`\Omega` is the entropic regularization term,
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
The formulation of the FGW problem has been proposed in
:ref:`[24] <references-entropic-partial-fused-gromov-wasserstein>` and the
partial GW in :ref:`[29] <references-entropic-partial-fused-gromov-wasserstein>`
Parameters
----------
M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
p : array-like, shape (ns,), optional
Distribution in the source space.
If left to its default value None, the uniform distribution is taken.
q : array-like, shape (nt,), optional
Distribution in the target space.
If left to its default value None, the uniform distribution is taken.
reg: float, optional. Default is 1.
entropic regularization parameter
m : float, optional
Amount of mass to be transported (default:
:math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
loss_fun : str, optional
Loss function used for the solver either 'square_loss' or 'kl_loss'.
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
G0 : array-like, shape (ns, nt), optional
Initialization of the transportation matrix
numItermax : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
symmetric : bool, optional
Whether C1 and C2 are to be assumed symmetric or not.
If left to its default None value, a symmetry test will be conducted.
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
log : bool, optional
return log if True
verbose : bool, optional
Print information along iterations
Returns
-------
:math:`gamma` : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
.. _references-entropic-partial-fused-gromov-wasserstein:
References
----------
.. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
.. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
Transport with Applications on Positive-Unlabeled Learning".
NeurIPS.
See Also
--------
ot.gromov.partial_fused_gromov_wasserstein: exact Partial Fused Gromov-Wasserstein
"""

arr = [M, C1, C2, G0]
if p is not None:
p = list_to_array(p)
arr.append(p)
if q is not None:
q = list_to_array(q)
arr.append(q)

nx = get_backend(*arr)

if p is None:
p = nx.ones(C1.shape[0], type_as=C1) / C1.shape[0]
if q is None:
q = nx.ones(C2.shape[0], type_as=C2) / C2.shape[0]

if m is None:
m = min(nx.sum(p), nx.sum(q))
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
elif m > min(nx.sum(p), nx.sum(q)):
raise ValueError(
"Problem infeasible. Parameter m should be lower than or"
" equal to min(|p|_1, |q|_1)."
)

if G0 is None:
G0 = (
nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q))
) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q.

else:
# Check marginals of G0
assert nx.all(nx.sum(G0, 1) <= p)
assert nx.all(nx.sum(G0, 0) <= q)

if symmetric is None:
symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(
C2, C2.T, atol=1e-10
)

# Setup gradient computation
fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx)
fC2t = fC2.T
if not symmetric:
fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T

ones_p = nx.ones(p.shape[0], type_as=p)
ones_q = nx.ones(q.shape[0], type_as=q)

def f(G):
pG = nx.sum(G, 1)
qG = nx.sum(G, 0)
constC1 = nx.outer(nx.dot(fC1, pG), ones_q)
constC2 = nx.outer(ones_p, nx.dot(qG, fC2t))
return alpha * gwloss(constC1 + constC2, hC1, hC2, G, nx) + (
1 - alpha
) * nx.sum(G * M)

if symmetric:

def df(G):
pG = nx.sum(G, 1)
qG = nx.sum(G, 0)
constC1 = nx.outer(nx.dot(fC1, pG), ones_q)
constC2 = nx.outer(ones_p, nx.dot(qG, fC2t))
# gradient of the linear feature term (1 - alpha) * <G, M>_F is (1 - alpha) * M
return alpha * gwggrad(constC1 + constC2, hC1, hC2, G, nx) + (1 - alpha) * M
else:

def df(G):
pG = nx.sum(G, 1)
qG = nx.sum(G, 0)
constC1 = nx.outer(nx.dot(fC1, pG), ones_q)
constC2 = nx.outer(ones_p, nx.dot(qG, fC2t))
constC1t = nx.outer(nx.dot(fC1t, pG), ones_q)
constC2t = nx.outer(ones_p, nx.dot(qG, fC2))

# average the two (transposed) GW gradients, then add the gradient of the
# linear feature term, (1 - alpha) * M
return 0.5 * alpha * (
gwggrad(constC1 + constC2, hC1, hC2, G, nx)
+ gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx)
) + (1 - alpha) * M

cpt = 0
err = 1

loge = {"err": []}

while err > tol and cpt < numItermax:
Gprev = G0
# Linearize the fused GW objective at G0 and solve the resulting
# entropic partial OT subproblem with the gradient as cost matrix.
M_entr = df(G0)
G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m)
if cpt % 10 == 0: # to speed up the computations
err = np.linalg.norm(G0 - Gprev)
if log:
loge["err"].append(err)
if verbose:
if cpt % 200 == 0:
print(
"{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss")
+ "\n"
+ "-" * 31
)
print("{:5d}|{:8e}|{:8e}".format(cpt, err, f(G0)))

cpt += 1

if log:
loge["partial_fgw_dist"] = f(G0)
return G0, loge
else:
return G0
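
A hedged usage sketch for the solver defined above: only the call signature and the `partial_fgw_dist` log key come from the code in this diff; the random cost matrices, the feature-distance matrix and the parameter values are illustrative assumptions.

# Illustrative example (assumed synthetic data; signature taken from the solver above).
import numpy as np
from ot.gromov import entropic_partial_fused_gromov_wasserstein

rng = np.random.RandomState(0)
ns, nt = 20, 30
C1 = rng.rand(ns, ns)
C1 = (C1 + C1.T) / 2  # symmetric structure cost in the source space
C2 = rng.rand(nt, nt)
C2 = (C2 + C2.T) / 2  # symmetric structure cost in the target space
M = rng.rand(ns, nt)  # feature distance matrix across domains

# Transport 80% of the (uniform) mass with a mild entropic regularization.
T, log = entropic_partial_fused_gromov_wasserstein(
    M, C1, C2, reg=1e-1, m=0.8, alpha=0.5, log=True
)
print(T.shape)                  # (ns, nt); total transported mass is close to m
print(log["partial_fgw_dist"])  # fused GW cost of the returned partial plan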


def entropic_partial_fused_gromov_wasserstein2(
M,
C1,
C2,
p=None,
q=None,
reg=1.0,
m=None,
loss_fun="square_loss",
alpha=0.5,
G0=None,
numItermax=1000,
tol=1e-7,
symmetric=None,
log=False,
verbose=False,
):
r"""
Returns the entropic partial Fused Gromov-Wasserstein discrepancy between
:math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and
:math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise
distance matrix :math:`\mathbf{M}` between node feature matrices.
The function solves the following optimization problem:
.. math::
PFGW = \min_{\gamma} \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \cdot
\gamma_{i,j} \cdot \gamma_{k,l} + \mathrm{reg} \cdot \Omega(\gamma)
.. math::
s.t. \ \gamma &\geq 0
\gamma \mathbf{1} &\leq \mathbf{p}
\gamma^T \mathbf{1} &\leq \mathbf{q}
\mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}
where :
- :math:`\mathbf{M}`: metric cost matrix between features across domains
- :math:`\mathbf{C_1}` is the metric cost matrix in the source space
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L`: Loss function to account for the misfit between the similarity matrices.
- :math:`\Omega` is the entropic regularization term,
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
The formulation of the FGW problem has been proposed in
:ref:`[24] <references-entropic-partial-fused-gromov-wasserstein2>` and the
partial GW in :ref:`[29] <references-entropic-partial-fused-gromov-wasserstein2>`
Parameters
----------
M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric cost matrix in the target space
p : array-like, shape (ns,), optional
Distribution in the source space.
If left to its default value None, the uniform distribution is taken.
q : array-like, shape (nt,), optional
Distribution in the target space.
If left to its default value None, the uniform distribution is taken.
reg: float, optional. Default is 1.
entropic regularization parameter
m : float, optional
Amount of mass to be transported (default:
:math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
loss_fun : str, optional
Loss function used for the solver either 'square_loss' or 'kl_loss'.
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
G0 : ndarray, shape (ns, nt), optional
Initialization of the transportation matrix
numItermax : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
symmetric : bool, optional
Whether C1 and C2 are to be assumed symmetric or not.
If left to its default None value, a symmetry test will be conducted.
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
log : bool, optional
return log if True
verbose : bool, optional
Print information along iterations
Returns
-------
partial_fgw_dist: float
Partial Entropic Fused Gromov-Wasserstein discrepancy
log : dict
log dictionary returned only if `log` is `True`
.. _references-entropic-partial-fused-gromov-wasserstein2:
References
----------
.. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
.. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
Transport with Applications on Positive-Unlabeled Learning".
NeurIPS.
"""
nx = get_backend(M, C1, C2)

T, log_pfgw = entropic_partial_fused_gromov_wasserstein(
M,
C1,
C2,
p,
q,
reg,
m,
loss_fun,
alpha,
G0,
numItermax,
tol,
symmetric,
True,
verbose,
)

log_pfgw["T"] = T

# setup for ot.solve_gromov
lin_term = nx.sum(T * M)
log_pfgw["quad_loss"] = log_pfgw["partial_fgw_dist"] - (1 - alpha) * lin_term
log_pfgw["lin_loss"] = lin_term * (1 - alpha)

if log:
return log_pfgw["partial_fgw_dist"], log_pfgw
else:
return log_pfgw["partial_fgw_dist"]
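
Similarly, a usage sketch for the value-returning variant defined above; the log keys referenced in the comments (`T`, `lin_loss`, `quad_loss`) are those populated by the code in this diff, while the data are again assumed.

# Illustrative example for the value-returning variant (assumed synthetic data).
import numpy as np
from ot.gromov import entropic_partial_fused_gromov_wasserstein2

rng = np.random.RandomState(0)
ns, nt = 20, 30
C1 = rng.rand(ns, ns)
C1 = (C1 + C1.T) / 2
C2 = rng.rand(nt, nt)
C2 = (C2 + C2.T) / 2
M = rng.rand(ns, nt)

pfgw_dist, log = entropic_partial_fused_gromov_wasserstein2(
    M, C1, C2, reg=1e-1, m=0.8, alpha=0.5, log=True
)
print(pfgw_dist)          # partial entropic fused GW discrepancy
print(log["lin_loss"])    # (1 - alpha) * <T, M>_F, linear (feature) part
print(log["quad_loss"])   # remaining quadratic (structure) part
print(log["T"].shape)     # optimal partial transport plan, shape (ns, nt)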