Skip to content

Commit

Permalink
Memory optimized dists_add_symmetric
Browse files Browse the repository at this point in the history
  • Loading branch information
KushnirDmytro committed Apr 11, 2023
1 parent f526e29 commit 539672c
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions cosypose/lib3d/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ def dists_add(TXO_pred, TXO_gt, points):
dists = TXO_gt_points - TXO_pred_points
return dists


def dists_add_symmetric(TXO_pred, TXO_gt, points):
TXO_pred_points = transform_pts(TXO_pred, points)
TXO_gt_points = transform_pts(TXO_gt, points)
dists = TXO_gt_points.unsqueeze(1) - TXO_pred_points.unsqueeze(2)
dists_norm_squared = (dists ** 2).sum(dim=-1)
assign = dists_norm_squared.argmin(dim=1)
ids_row = torch.arange(dists.shape[0]).unsqueeze(1).repeat(1, dists.shape[1])
ids_col = torch.arange(dists.shape[1]).unsqueeze(0).repeat(dists.shape[0], 1)
dists = dists[ids_row, assign, ids_col]
return dists
distances = torch.cdist(TXO_gt_points, TXO_pred_points,
p=2, compute_mode='donot_use_mm_for_euclid_dist')
closest_points_idx = torch.argmin(distances, dim=2).squeeze()
TXO_pred_closest_to_gt = torch.index_select(TXO_pred_points, 1, closest_points_idx)
min_translations = TXO_gt_points - TXO_pred_closest_to_gt
return min_translations

0 comments on commit 539672c

Please sign in to comment.