diff --git a/cosypose/lib3d/distances.py b/cosypose/lib3d/distances.py index 308b573..cc8343e 100644 --- a/cosypose/lib3d/distances.py +++ b/cosypose/lib3d/distances.py @@ -8,14 +8,12 @@ def dists_add(TXO_pred, TXO_gt, points): dists = TXO_gt_points - TXO_pred_points return dists - def dists_add_symmetric(TXO_pred, TXO_gt, points): TXO_pred_points = transform_pts(TXO_pred, points) TXO_gt_points = transform_pts(TXO_gt, points) - dists = TXO_gt_points.unsqueeze(1) - TXO_pred_points.unsqueeze(2) - dists_norm_squared = (dists ** 2).sum(dim=-1) - assign = dists_norm_squared.argmin(dim=1) - ids_row = torch.arange(dists.shape[0]).unsqueeze(1).repeat(1, dists.shape[1]) - ids_col = torch.arange(dists.shape[1]).unsqueeze(0).repeat(dists.shape[0], 1) - dists = dists[ids_row, assign, ids_col] - return dists + distances = torch.cdist(TXO_gt_points, TXO_pred_points, + p=2, compute_mode='donot_use_mm_for_euclid_dist') + closest_points_idx = torch.argmin(distances, dim=2).squeeze() + TXO_pred_closest_to_gt = torch.index_select(TXO_pred_points, 1, closest_points_idx) + min_translations = TXO_gt_points - TXO_pred_closest_to_gt + return min_translations