Skip to content

Commit

Permalink
keep developing the python method
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jun 4, 2024
1 parent 89e22a9 commit 8ac7b4e
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions src/main/java/ai/nets/samj/models/PythonMethods.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ public class PythonMethods {


protected static String SAM_EVERYTHING = ""
+ "def calculate_pairs(masks):\n"
+ " added_masks = masks.sum(2)\n"
+ " inds = np.where(added_masks > 1)\n"
+ " pairs = np.zeros((0, 2))\n"
+ " for ii in range(inds.shape[0]):\n"
+ " overlapping = np.where(masks[inds[0][ii], inds[1][ii]])[0]\n"
+ " for i in range(overlapping[0].shape[0]):\n"
+ " for j in range(i + 1, overlapping[0].shape[0]):\n"
+ " pp = np.unique(np.array([overlapping[i], overlapping[j]])).reshape(-1, 2)\n"
+ " matches = np.all(pairs == pp, axis=1)\n"
+ " if not matches.any():\n"
+ " pairs = np.concatenate((pairs, pp), axis=0)\n"
+ " return pairs\n"
+ "\n"
+ "def sam_everything(point_list, return_all=False):\n"
+ " masks = np.zeros((input_h, input_w, 0), dtype='uint8')\n"
+ " scores = []\n"
Expand All @@ -125,24 +139,28 @@ public class PythonMethods {
+ " if predicted_iou[0] > 1:\n"
+ " masks = np.concatenate((masks, mask.reshape(mask.shape[0], mask.shape[1], 1)), axis=1)\n"
+ " scores.append(predicted_iou[0].cpu().detach().numpy())\n"
+ " added_masks = masks.sum(2)\n"
+ " inds = np.where(added_masks > 1)\n"
+ "\n"
+ " # TODO do we support detection of objects within objects?\n"
+ " while inds[0].shape[0] > 0:\n"
+ " overlapping = np.where(masks[inds[0][0], inds[1][0]])[0]\n"
+ " mask_sum = (masks[:, :, overlapping[0]] + masks[:, :, overlapping[1]])\n"
+ " pairs = calculate_pairs(masks)\n"
+ " while pairs.shape[0] != 0:\n"
+ " pp = pairs[0]\n"
+ " mask_sum = (masks[:, :, pp[0]] + masks[:, :, pp[1]])\n"
+ " union = (mask_sum > 0).sum()\n"
+ " intersec = (mask_sum == 2).sum()\n"
+ " if intersec / union > 0.8:\n"
+ " new_mask = (mask_sum > 0) * 1\n"
+ " masks = np.delete(masks, overlapping[:2], axis=2)\n"
+ " masks = np.delete(masks, interest, axis=2)\n"
+ " masks = np.concatenate((masks, new_mask.reshape(masks.shape[0], masks.shape[1], 1)), axis=2)\n"
+ " elif intersec == masks[:, :, overlapping[0]].sum() or intersec == masks[:, :, overlapping[1]].sum():\n"
+ " pair += [[overlapping[0], overlapping[1]]]\n"
+ " pairs = calculate_pairs(masks)\n"
+ " elif intersec == masks[:, :, pp[0]].sum() or intersec == masks[:, :, pp[1]].sum():\n"
+ " pairs = np.concatenate((pairs, np.unique(np.array([pp[0], pp[1]])).reshape(-1, 2)), axis=0)\n"
+ " else:\n"
+ " ## TODO run again precition\n"
+ " if score[overlapping[0]] > score[overlapping[1]]:\n"
+ " (masks[:, :, overlapping[1]] = ((masks[:, :, overlapping[1]] - masks[:, :, overlapping[0]]) > 0) * 1\n"
+ " if score[pp[0]] > score[pp[1]]:\n"
+ " (masks[:, :, pp[1]] = ((masks[:, :, pp[1]] - masks[:, :, pp[0]]) > 0) * 1\n"
+ " else:\n"
+ " (masks[:, :, pp[0]] = ((masks[:, :, pp[0]] - masks[:, :, pp[1]]) > 0) * 1\n"
+ " pairs = calculate_pairs(masks)\n"
+ " added_masks = masks.sum(2)\n"
+ " inds = np.where(added_masks > 1)\n"
+ " label = np.arange(masks.shape[2])\n"
Expand Down

0 comments on commit 8ac7b4e

Please sign in to comment.