Skip to content

Commit

Permalink
Sinkhorn Distance, Chamfer Distance, and pairwise distances between p…
Browse files Browse the repository at this point in the history
…oints
  • Loading branch information
fwilliams committed Jul 11, 2019
1 parent 66a5531 commit 4412e71
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 8 deletions.
16 changes: 8 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
endif()

# Create the python module
npe_add_module(point_cloud_utils
npe_add_module(pcu_internal
BINDING_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/sample.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/point_cloud_distance.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/lloyd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/meshio.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/mesh_utils.cpp)
target_include_directories(point_cloud_utils PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/vcglib)
target_include_directories(point_cloud_utils PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_include_directories(point_cloud_utils PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/nanoflann)
target_include_directories(point_cloud_utils PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external)
target_link_libraries(point_cloud_utils PRIVATE geogram)
set_target_properties(point_cloud_utils PROPERTIES COMPILE_FLAGS "-fvisibility=hidden -msse3")
target_include_directories(pcu_internal PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/vcglib)
target_include_directories(pcu_internal PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_include_directories(pcu_internal PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/nanoflann)
target_include_directories(pcu_internal PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external)
target_link_libraries(pcu_internal PRIVATE geogram)
set_target_properties(pcu_internal PROPERTIES COMPILE_FLAGS "-fvisibility=hidden -msse3")

if (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(point_cloud_utils PUBLIC OpenMP::OpenMP_CXX)
target_link_libraries(pcu_internal PUBLIC OpenMP::OpenMP_CXX)
endif()

2 changes: 2 additions & 0 deletions point_cloud_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .pcu_internal import *
from .sinkhorn import *
Binary file added point_cloud_utils/__init__.pyc
Binary file not shown.
Binary file not shown.
Binary file added point_cloud_utils/pcu_internal.so
Binary file not shown.
121 changes: 121 additions & 0 deletions point_cloud_utils/sinkhorn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np

def pairwise_distances(a, b, p=2, squeeze=False):
"""
Compute the pairwise distance matrix between a and b which both have size [m, n, d] or [n, d]. The result is a tensor of
size [m, n, n] (or [n, n]) whose entry [m, i, j] contains the distance_tensor between a[m, i, :] and b[m, j, :].
:param a: A tensor containing m batches of n points of dimension d. i.e. of size [m, n, d]
:param b: A tensor containing m batches of n points of dimension d. i.e. of size [m, n, d]
:param p: Norm to use for the distance_tensor
:param squeeze: If set, any redundant dimensions will be squeezed in the output
:return: A tensor containing the pairwise distance_tensor between each pair of inputs in a batch.
"""

if len(a.shape) == 2 and len(b.shape) == 2:
a = a[np.newaxis, :, :]
b = b[np.newaxis, :, :]

if len(a.shape) != 3:
raise ValueError("Invalid shape for a. Must be [m, n, d] or [n, d] but got", a.shape)
if len(b.shape) != 3:
raise ValueError("Invalid shape for a. Must be [m, n, d] or [n, d] but got", b.shape)

ret = np.power(a[:, :, np.newaxis, :] - b[:, np.newaxis, :, :], p).sum(3)
if squeeze:
ret = np.squeeze(ret)

return ret


def chamfer(a, b):
"""
Compute the chamfer distance between two sets of vectors, a, and b
:param a: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_a, d]
:param b: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_b, d]
:return: A [m] shaped tensor storing the Chamfer distance between each minibatch entry
"""
M = pairwise_distances(a, b, squeeze=False)
print(M.shape)
return M.min(1).sum(1) + M.min(2).sum(1)


def sinkhorn(a, b, M, eps, max_iters=100, stop_thresh=1e-3):
"""
Compute the Sinkhorn divergence between two sum of dirac delta distributions, U, and V.
This implementation is numerically stable with float32.
:param a: A m-sized minibatch of weights for each dirac in the first distribution, U. i.e. shape = [m, n]
:param b: A m-sized minibatch of weights for each dirac in the second distribution, V. i.e. shape = [m, n]
:param M: A minibatch of n-by-n tensors storing the distance between each pair of diracs in U and V.
i.e. shape = [m, n, n] and each i.e. M[k, i, j] = ||u[k,_i] - v[k, j]||
:param eps: The reciprocal of the sinkhorn regularization parameter
:param max_iters: The maximum number of Sinkhorn iterations
:param stop_thresh: Stop if the change in iterates is below this value
:return:
"""
# a and b are tensors of size [m, n]
# M is a tensor of size [m, n, n]

M = np.squeeze(M)
a = np.squeeze(a)
b = np.squeeze(b)
squeezed = False
if len(M.shape) == 2 and len(a.shape) == 1 and len(b.shape) == 1:
M = M[np.newaxis, :, :]
a = a[np.newaxis, :]
b = b[np.newaxis, :]
squeezed = True

if len(M.shape) != 3:
raise ValueError("Got unexpected shape for M (%s), should be [nb, m, n] where nb is batch size, and "
"m and n are the number of samples in the two input measures." % str(M.shape))

nb = M.shape[0]
m = M.shape[1]
n = M.shape[2]

if a.dtype != b.dtype or a.dtype != M.dtype:
raise ValueError("Tensors a, b, and M must have the same dtype got: dtype(a) = %s, dtype(b) = %s, dtype(M) = %s"
% (str(a.dtype), str(b.dtype), str(M.dtype)))
if a.shape != (nb, m):
raise ValueError("Got unexpected shape for tensor a (%s). Expected [nb, m] where M has shape [nb, m, n]." %
str(a.shape))

if b.shape != (nb, n):
raise ValueError("Got unexpected shape for tensor b (%s). Expected [nb, n] where M has shape [nb, m, n]." %
str(b.shape))

u = np.zeros_like(a)
v = np.zeros_like(b)

M_t = np.transpose(M, axes=(0, 2, 1))

def stabilized_log_sum_exp(x):
max_x = x.max(2)
x = x - max_x[:, :, np.newaxis]
ret = np.log(np.sum(np.exp(x), axis=2)) + max_x
return ret

for current_iter in range(max_iters):
u_prev = u
v_prev = v

summand_u = (-M + np.expand_dims(v, 1)) / eps
u = eps * (np.log(a) - stabilized_log_sum_exp(summand_u))

summand_v = (-M_t + np.expand_dims(u, 1)) / eps
v = eps * (np.log(b) - stabilized_log_sum_exp(summand_v))

err_u = np.sum(np.abs(u_prev-u), axis=1).max()
err_v = np.sum(np.abs(v_prev-v), axis=1).max()

if err_u < stop_thresh and err_v < stop_thresh:
break

log_P = (-M + np.expand_dims(u, 2) + np.expand_dims(v, 1)) / eps

P = np.exp(log_P)

if squeezed:
P = np.squeeze(P)

return P
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def run(self):

def build_extension(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
extdir = os.path.join(extdir, "point_cloud_utils")
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, '-DPYTHON_EXECUTABLE=' + sys.executable]
cmake_args.extend(ext.cmake_args)

Expand Down

0 comments on commit 4412e71

Please sign in to comment.