diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index 1132a19..a2f5066 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -398,8 +398,8 @@ public List processPoints(List pointsList, List pointsNeg Rectangle rect = new Rectangle(); rect.x = (int) this.encodeCoords[0]; rect.y = (int) this.encodeCoords[1]; - rect.height = (int) this.targetDims[0]; - rect.width = (int) this.targetDims[1]; + rect.height = (int) this.targetDims[1]; + rect.width = (int) this.targetDims[0]; return processPoints(pointsList, pointsNegList, rect, true); } @@ -432,8 +432,8 @@ public List processPoints(List pointsList, List pointsNeg Rectangle rect = new Rectangle(); rect.x = (int) this.encodeCoords[0]; rect.y = (int) this.encodeCoords[1]; - rect.height = (int) this.targetDims[0]; - rect.width = (int) this.targetDims[1]; + rect.height = (int) this.targetDims[1]; + rect.width = (int) this.targetDims[0]; return processPoints(pointsList, pointsNegList, rect, returnAll); } @@ -621,16 +621,16 @@ private ArrayList getPointsNotInRect(List pointsList, List } public Rectangle getCurrentlyEncodedArea() { - int xMargin = (int) (targetDims[1] * 0.1); - int yMargin = (int) (targetDims[0] * 0.1); + int xMargin = (int) (targetDims[0] * 0.1); + int yMargin = (int) (targetDims[1] * 0.1); Rectangle alreadyEncoded; if (encodeCoords[0] != 0 || encodeCoords[1] != 0 || targetDims[1] != this.img.dimensionsAsLongArray()[1] || targetDims[0] != this.img.dimensionsAsLongArray()[0]) { alreadyEncoded = new Rectangle((int) encodeCoords[0] + xMargin / 2, (int) encodeCoords[1] + yMargin / 2, - (int) targetDims[1] - xMargin, (int) targetDims[0] - yMargin); + (int) targetDims[0] - xMargin, (int) targetDims[1] - yMargin); } else { alreadyEncoded = new Rectangle((int) encodeCoords[0], (int) encodeCoords[1], - (int) targetDims[1], (int) targetDims[0]); + (int) targetDims[0], (int) targetDims[1]); } return alreadyEncoded; } @@ -655,7 +655,7 @@ private void evaluateReencodingNeeded(List pointsList, List points ArrayList notInRect = getPointsNotInRect(pointsList, pointsNegList, rect); if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x - && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y + && alreadyEncoded.height + alreadyEncoded.y >= rect.height + rect.y && alreadyEncoded.width * 0.9 < rect.width && alreadyEncoded.height * 0.9 < rect.height && notInRect.size() == 0) { return; @@ -664,13 +664,13 @@ private void evaluateReencodingNeeded(List pointsList, List points this.reencodeCrop(new long[] {rect.width, rect.height}); } else if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x - && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y + && alreadyEncoded.height + alreadyEncoded.y >= rect.height + rect.y && (alreadyEncoded.width * 0.9 > rect.width || alreadyEncoded.height * 0.9 > rect.height)) { this.encodeCoords = new long[] {rect.x, rect.y}; this.reencodeCrop(new long[] {rect.width, rect.height}); } else if (alreadyEncoded.x <= neededArea.x && alreadyEncoded.y <= neededArea.y && alreadyEncoded.width + alreadyEncoded.x >= neededArea.width + neededArea.x - && alreadyEncoded.height + alreadyEncoded.y >= neededArea.width + neededArea.y + && alreadyEncoded.height + alreadyEncoded.y >= neededArea.height + neededArea.y && alreadyEncoded.width * 0.9 < neededArea.width && alreadyEncoded.height * 0.9 < neededArea.height && notInRect.size() == 0) { return; @@ -770,8 +770,8 @@ public boolean isAreaEncoded(int[] boundingBox) { public boolean needsMoreResolution(int[] boundingBox) { long xSize = boundingBox[2] - boundingBox[0]; long ySize = boundingBox[3] - boundingBox[1]; - long encodedX = targetDims[1]; - long encodedY = targetDims[0]; + long encodedX = targetDims[0]; + long encodedY = targetDims[1]; if (xSize * LOWER_REENCODE_THRESH < encodedX && ySize * LOWER_REENCODE_THRESH < encodedY) return true; return false; @@ -785,8 +785,8 @@ public boolean needsMoreResolution(int[] boundingBox) { public boolean boundingBoxTooBig(int[] boundingBox) { long xSize = boundingBox[2] - boundingBox[0]; long ySize = boundingBox[3] - boundingBox[1]; - long encodedX = targetDims[1]; - long encodedY = targetDims[0]; + long encodedX = targetDims[0]; + long encodedY = targetDims[1]; if (xSize * UPPER_REENCODE_THRESH > encodedX && ySize * UPPER_REENCODE_THRESH > encodedY) return true; return false; diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index cff6b80..1aecef3 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -218,7 +218,7 @@ protected void createEncodeImageScript() { code += "globals()['input_h'] = input_h" + System.lineSeparator(); code += "globals()['input_w'] = input_w" + System.lineSeparator(); code += "task.update(str(im.shape))" + System.lineSeparator(); - code += "im = torch.from_numpy(np.transpose(im))" + System.lineSeparator(); + code += "im = torch.from_numpy(np.transpose(im, (2, 1, 0)))" + System.lineSeparator(); code += "task.update('after ' + str(im.shape))" + System.lineSeparator(); code += "im_shm.unlink()" + System.lineSeparator(); //code += "box_shm.close()" + System.lineSeparator(); diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 08446bc..5bb9183 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -313,11 +313,13 @@ protected void createEncodeImageScript() { + ")" + System.lineSeparator(); int size = 1; for (long l : targetDims) {size *= l;} - script += "im = np.ndarray(" + size + ", dtype='" + "', buffer=im_shm.buf).reshape(["; + script += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI())) + + "', buffer=im_shm.buf).reshape(["; for (long ll : targetDims) script += ll + ", "; script = script.substring(0, script.length() - 2); script += "])" + System.lineSeparator(); + script += "im = np.transpose(im, (1, 0, 2))" + System.lineSeparator(); //code += "np.save('/home/carlos/git/aa.npy', im)" + System.lineSeparator(); script += "im_shm.unlink()" + System.lineSeparator(); //code += "box_shm.close()" + System.lineSeparator();