Skip to content

Commit

Permalink
keep iterating
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 30, 2024
1 parent 7d41972 commit 1ae2e84
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 22 deletions.
4 changes: 3 additions & 1 deletion src/main/java/ai/nets/samj/gui/MainGUI.java
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,9 @@ private void toggleDrawer() {
}

private < T extends RealType< T > & NativeType< T > > void batchSAMize() throws IOException, RuntimeException, InterruptedException {
RandomAccessibleInterval<T> rai = this.consumer.getFocusedImageAsRai();
RandomAccessibleInterval<T> rai = null;
if (this.consumer.getFocusedImage() != this.cmbImages.getSelectedObject())
rai = this.consumer.getFocusedImageAsRai();
List<int[]> pointPrompts = this.consumer.getPointRoisOnFocusImage();
List<Rectangle> rectPrompts = this.consumer.getRectRoisOnFocusImage();
if (pointPrompts.size() == 0 && rectPrompts.size() == 0 && !(rai.getType() instanceof IntegerType)){
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ public interface DebugTextPrinter { void printText(String text); }

protected abstract void cellSAM(List<int[]> grid, boolean returnAll);

protected abstract void
processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll);
protected abstract void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll);

protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll);

Expand Down Expand Up @@ -436,13 +435,14 @@ List<Mask> processBatchOfPrompts(List<int[]> pointsList, List<Rectangle> rects,
rectPrompts = rects.stream().map(rr -> new int[] {rr.x, rr.y, rr.x + rr.width, rr.y + rr.height})
.collect(Collectors.toList());
inputs.put("rect_prompts", rectPrompts);
processPromptsBatchWithSAM(pointsList.size(), rects.size(), maskShma, returnAll);
processPromptsBatchWithSAM(maskShma, returnAll);
printScript(script, "Batch of prompts inference");
List<Mask> polys = processAndRetrieveContours(null);
List<Mask> polys = processAndRetrieveContours(inputs);
recalculatePolys(polys, encodeCoords);
return polys;
} catch (IOException | RuntimeException | InterruptedException ex) {
maskShma.close();
if (maskShma != null)
maskShma.close();
throw ex;
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
*/
package ai.nets.samj.models;

import java.awt.Rectangle;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -387,7 +386,7 @@ public String deleteEncodingScript(String encodingName) {
}

@Override
protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr,
protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
boolean returnAll) {
String code = "";
code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator();
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import ai.nets.samj.install.EfficientViTSamEnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;

import java.awt.Rectangle;
import java.io.File;
import java.io.IOException;

Expand Down Expand Up @@ -462,7 +461,7 @@ public String deleteEncodingScript(String encodingName) {
}

@Override
protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr,
protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
boolean returnAll) {
String code = "";
code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator();
Expand Down
30 changes: 18 additions & 12 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -446,19 +446,23 @@ public String deleteEncodingScript(String encodingName) {
}

@Override
protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll) {
String code = "";
code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator();
code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
for (long l : shmArr.getOriginalShape())
code += l + ",";
code += "])" + System.lineSeparator();
protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll) {
String code = ""
+ "num_features = 0" + System.lineSeparator();
if (shmArr != null) {
code += ""
+ "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator()
+ "mask_batch = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
for (long l : shmArr.getOriginalShape())
code += l + ",";
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
}
code += ""
+ "point_prompts = []" + System.lineSeparator()
+ "point_labels = []" + System.lineSeparator()
+ "mask_prompts = []" + System.lineSeparator()
+ "mask_labels = []" + System.lineSeparator()
+ "labeled_array, num_features = label(mask)" + System.lineSeparator()
+ "contours_x = []" + System.lineSeparator()
+ "contours_y = []" + System.lineSeparator()
+ "rle_masks = []" + System.lineSeparator()
Expand All @@ -483,7 +487,7 @@ protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryA
+ " contours_y += c_Y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "for p_prompt in range(" + npoints + "):" + System.lineSeparator()
+ "for p_prompt in point_prompts:" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=np.array(p_prompt).reshape(1, 2)," + System.lineSeparator()
+ " point_labels=np.array([1])," + System.lineSeparator()
Expand All @@ -496,7 +500,7 @@ protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryA
+ " rle_masks += r_m" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "for rect_prompt in range(" + nrects + "):" + System.lineSeparator()
+ "for rect_prompt in rect_prompts:" + System.lineSeparator()
+ " input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=None," + System.lineSeparator()
Expand All @@ -519,8 +523,10 @@ protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryA
+ "task.outputs['contours_y'] = contours_y" + System.lineSeparator()
+ "task.outputs['rle'] = rle_masks" + System.lineSeparator();
code += "mask = 0" + System.lineSeparator();
code += "shm_mask.close()" + System.lineSeparator();
code += "shm_mask.unlink()" + System.lineSeparator();
if (shmArr != null) {
code += "shm_mask.close()" + System.lineSeparator();
code += "shm_mask.unlink()" + System.lineSeparator();
}
this.script = code;
}
}

0 comments on commit 1ae2e84

Please sign in to comment.