From 2480c697a2d36f0330ad0025ad354b6884dc430c Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 25 Apr 2024 14:25:43 +0200 Subject: [PATCH] improve the formatting of images for SAM --- .../ai/nets/samj/models/AbstractSamJ.java | 29 ++++++++++++++----- .../ai/nets/samj/models/EfficientSamJ.java | 7 ++++- .../ai/nets/samj/models/EfficientViTSamJ.java | 7 ++++- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index 5528425..9ea993b 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -464,7 +464,7 @@ public List processPoints(List pointsList, List pointsNeg Objects.requireNonNull(encodingArea, "Third argument cannot be null. Use the method " + "'processPoints(List pointsList, List pointsNegList, Rectangle zoomedArea, boolean returnAll)'" + " instead"); - + boolean providedReencodingArea = true; if (encodingArea.x == -1) { encodingArea = getCurrentlyEncodedArea(); } else { @@ -690,8 +690,13 @@ private void evaluateReencodingNeeded(List pointsList, List points } else if (notInRect.size() != 0) { this.encodeCoords = new long[] {(long) Math.max(0, rect.x - rect.width * 0.1), (long) Math.max(0, rect.y - rect.height * 0.1)}; - this.reencodeCrop(new long[] {(long) Math.min(rect.width * 0.2, img.dimensionsAsLongArray()[0] - encodeCoords[0]), - (long) Math.min(rect.height * 0.2, img.dimensionsAsLongArray()[1] - encodeCoords[1])}); + long[] imgDims = this.img.dimensionsAsLongArray(); + long width = Math.min(imgDims[0], Math.max(rect.x + rect.width, neededArea.x + neededArea.width) - encodeCoords[0]); + long height = Math.min(imgDims[1], Math.max(rect.y + rect.height, neededArea.y + neededArea.height) - encodeCoords[1]); + if (alreadyEncoded.x == encodeCoords[0] && alreadyEncoded.y == encodeCoords[1] + && alreadyEncoded.width == width && alreadyEncoded.height == height) + return; + this.reencodeCrop(new long[] {width, height}); } else if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x && alreadyEncoded.height + alreadyEncoded.y >= rect.height + rect.y @@ -699,8 +704,13 @@ private void evaluateReencodingNeeded(List pointsList, List points && rect.contains(neededArea)) { this.encodeCoords = new long[] {(long) Math.max(0, rect.x - rect.width * 0.1), (long) Math.max(0, rect.y - rect.height * 0.1)}; - this.reencodeCrop(new long[] {(long) Math.min(rect.width * 0.2, img.dimensionsAsLongArray()[0] - encodeCoords[0]), - (long) Math.min(rect.height * 0.2, img.dimensionsAsLongArray()[1] - encodeCoords[1])}); + long[] imgDims = this.img.dimensionsAsLongArray(); + long width = Math.min(imgDims[0], Math.max(rect.x + rect.width, neededArea.x + neededArea.width) - encodeCoords[0]); + long height = Math.min(imgDims[1], Math.max(rect.y + rect.height, neededArea.y + neededArea.height) - encodeCoords[1]); + if (alreadyEncoded.x == encodeCoords[0] && alreadyEncoded.y == encodeCoords[1] + && alreadyEncoded.width == width && alreadyEncoded.height == height) + return; + this.reencodeCrop(new long[] {width, height}); } else if (alreadyEncoded.x <= neededArea.x && alreadyEncoded.y <= neededArea.y && alreadyEncoded.width + alreadyEncoded.x >= neededArea.width + neededArea.x && alreadyEncoded.height + alreadyEncoded.y >= neededArea.height + neededArea.y @@ -709,8 +719,13 @@ private void evaluateReencodingNeeded(List pointsList, List points return; } else { this.encodeCoords = new long[] {Math.min(rect.x, neededArea.x), Math.min(rect.y, neededArea.y)}; - this.reencodeCrop(new long[] {Math.max(rect.x + rect.width, neededArea.x + neededArea.width) - encodeCoords[0], - Math.max(rect.y + rect.height, neededArea.y + neededArea.height) - encodeCoords[1]}); + long[] imgDims = this.img.dimensionsAsLongArray(); + long width = Math.min(imgDims[0], Math.max(rect.x + rect.width, neededArea.x + neededArea.width) - encodeCoords[0]); + long height = Math.min(imgDims[1], Math.max(rect.y + rect.height, neededArea.y + neededArea.height) - encodeCoords[1]); + if (alreadyEncoded.x == encodeCoords[0] && alreadyEncoded.y == encodeCoords[1] + && alreadyEncoded.width == width && alreadyEncoded.height == height) + return; + this.reencodeCrop(new long[] {width, height}); } } diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index ae7c5ca..b734654 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -391,7 +391,12 @@ public static void main(String[] args) throws IOException, RuntimeException, Int protected & NativeType> void setImageOfInterest(RandomAccessibleInterval rai) { checkImageIsFine(rai); long[] dims = rai.dimensionsAsLongArray(); - this.img = Views.interval(rai, new long[] {0, 0, 0}, new long[] {dims[0] - 1, dims[1] - 1, 2}); + if (dims.length == 2) + rai = Views.addDimension(rai, 0, 0); + if (dims[2] == 1) + rai = Views.interval( Views.expandMirrorDouble(rai, new long[] {0, 0, 2}), + Intervals.createMinMax(new long[] {0, 0, 0, dims[0] - 1, dims[1] - 1, 2}) ); + this.img = rai; this.targetDims = img.dimensionsAsLongArray(); } diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 611f8a4..5104ae1 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -301,7 +301,12 @@ else if (task.status == TaskStatus.CRASHED) protected & NativeType> void setImageOfInterest(RandomAccessibleInterval rai) { checkImageIsFine(rai); long[] dims = rai.dimensionsAsLongArray(); - this.img = Views.interval(rai, new long[] {0, 0, 0}, new long[] {dims[0] - 1, dims[1] - 1, 2}); + if (dims.length == 2) + rai = Views.addDimension(rai, 0, 0); + if (dims[2] == 1) + rai = Views.interval( Views.expandMirrorDouble(rai, new long[] {0, 0, 2}), + Intervals.createMinMax(new long[] {0, 0, 0, dims[0] - 1, dims[1] - 1, 2}) ); + this.img = rai; this.targetDims = img.dimensionsAsLongArray(); }