diff --git a/src/main/java/ai/nets/samj/AbstractSamJ2.java b/src/main/java/ai/nets/samj/AbstractSamJ2.java index e4944df..45d909b 100644 --- a/src/main/java/ai/nets/samj/AbstractSamJ2.java +++ b/src/main/java/ai/nets/samj/AbstractSamJ2.java @@ -29,37 +29,28 @@ import java.util.Map; import java.util.Objects; -import ai.nets.samj.AbstractSamJ.DebugTextPrinter; import java.awt.Polygon; import java.awt.Rectangle; -import java.io.File; import java.io.IOException; import io.bioimage.modelrunner.apposed.appose.Environment; import io.bioimage.modelrunner.apposed.appose.Service; import io.bioimage.modelrunner.apposed.appose.Service.Task; import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; -import io.bioimage.modelrunner.numpy.DecodeNumpy; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.IterableInterval; -import net.imglib2.RandomAccessible; import net.imglib2.RandomAccessibleInterval; import net.imglib2.converter.Converter; import net.imglib2.converter.Converters; -import net.imglib2.converter.RealTypeConverters; -import net.imglib2.img.array.ArrayImgs; -import net.imglib2.loops.LoopBuilder; import net.imglib2.type.NativeType; import net.imglib2.type.Type; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.UnsignedByteType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Cast; -import net.imglib2.util.Intervals; import net.imglib2.util.Util; -import net.imglib2.view.IntervalView; import net.imglib2.view.Views; /** @@ -166,9 +157,12 @@ void updateImage(RandomAccessibleInterval rai) throws IOException, RuntimeExc private & NativeType> void addImage(RandomAccessibleInterval rai) throws IOException, RuntimeException, InterruptedException { - adaptImage(rai); + setImageOfInterest(rai); + if (img.dimensionsAsLongArray()[0] * img.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS + || img.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || img.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) + return; this.script = ""; - sendImgLib2AsNp(rai); + sendImgLib2AsNp(); createEncodeImageScript(); try { printScript(script, "Creation of initial embeddings"); @@ -191,9 +185,9 @@ else if (task.status == TaskStatus.CRASHED) } } - protected abstract & NativeType> void adaptImage(RandomAccessibleInterval rai); + protected abstract & NativeType> void setImageOfInterest(RandomAccessibleInterval rai); - protected abstract void createEncodeImageScript(); + protected abstract & NativeType> void createEncodeImageScript(); private void reencodeCrop() throws IOException, InterruptedException, RuntimeException { reencodeCrop(null); @@ -742,10 +736,9 @@ public void close() { if (python != null) python.close(); } - private & NativeType> - void sendImgLib2AsNp(RandomAccessibleInterval targetImg) { - shma = createEfficientSAMInputSHM(reescaleIfNeeded(targetImg)); - RealTypeConverters.copyFromTo( (RandomAccessible) this.img, shma.getSharedRAI() ); + protected & NativeType> + void sendImgLib2AsNp() { + createSHMArray((RandomAccessibleInterval) this.img); String code = ""; // This line wants to recreate the original numpy array. Should look like: // input0_appose_shm = shared_memory.SharedMemory(name=input0) @@ -766,12 +759,14 @@ void sendImgLib2AsNp(RandomAccessibleInterval targetImg) { code += "globals()['input_h'] = input_h" + System.lineSeparator(); code += "globals()['input_w'] = input_w" + System.lineSeparator(); code += "task.update(str(im.shape))" + System.lineSeparator(); - code += "im = torch.from_numpy(np.transpose(im.astype('float32')))" + System.lineSeparator(); + code += "im = torch.from_numpy(np.transpose(im))" + System.lineSeparator(); code += "task.update('after ' + str(im.shape))" + System.lineSeparator(); code += "im_shm.unlink()" + System.lineSeparator(); //code += "box_shm.close()" + System.lineSeparator(); this.script += code; } + + protected abstract & NativeType> void createSHMArray(RandomAccessibleInterval imShared); private & NativeType> void sendCropAsNp(long[] cropSize) { if (cropSize == null) @@ -786,9 +781,7 @@ else if (cropSize.length == 2) //RandomAccessibleInterval crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize); //RandomAccessibleInterval crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize); targetDims = crop.dimensionsAsLongArray(); - shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {targetDims[0], targetDims[1], targetDims[2]}, - Util.getTypeFromInterval(crop)); - RealTypeConverters.copyFromTo(crop, shma.getSharedRAI()); + createSHMArray(crop); String code = ""; // This line wants to recreate the original numpy array. Should look like: // input0_appose_shm = shared_memory.SharedMemory(name=input0) @@ -808,7 +801,7 @@ else if (cropSize.length == 2) //code += "np.save('/home/carlos/git/cropped.npy', im)" + System.lineSeparator(); code += "globals()['input_h'] = input_h" + System.lineSeparator(); code += "globals()['input_w'] = input_w" + System.lineSeparator(); - code += "im = torch.from_numpy(np.transpose(im.astype('float32')))" + System.lineSeparator(); + code += "im = torch.from_numpy(np.transpose(im))" + System.lineSeparator(); code += "im_shm.unlink()" + System.lineSeparator(); //code += "box_shm.close()" + System.lineSeparator(); this.script += code; @@ -817,37 +810,6 @@ else if (cropSize.length == 2) protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll); protected abstract void processBoxWithSAM(boolean returnAll); - - private & NativeType> - SharedMemoryArray createEfficientSAMInputSHM(final RandomAccessibleInterval inImg) { - long[] dims = inImg.dimensionsAsLongArray(); - if ((dims.length != 3 && dims.length != 2) || (dims.length == 3 && dims[2] != 3 && dims[2] != 1)){ - throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." - + "The image dimensions order should be 'yxc', first dimension height, second width and third channels."); - } - return SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], 3}, - Cast.unchecked(Util.getTypeFromInterval(img))); - } - - private & NativeType> - void adaptImageToModel(final RandomAccessibleInterval ogImg, RandomAccessibleInterval targetImg) { - if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 3) { - for (int i = 0; i < 3; i ++) - RealTypeConverters.copyFromTo( normalizedView(Views.hyperSlice(ogImg, 2, i)), Views.hyperSlice(targetImg, 2, i) ); - } else if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 1) { - debugPrinter.printText("CONVERTED 1 CHANNEL IMAGE INTO 3 TO BE FEEDED TO SAMJ"); - IntervalView resIm = Views.interval( Views.expandMirrorDouble(normalizedView(ogImg), new long[] {0, 0, 2}), - Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0], ogImg.dimensionsAsLongArray()[1], 2}) ); - RealTypeConverters.copyFromTo( resIm, targetImg ); - } else if (ogImg.numDimensions() == 2) { - adaptImageToModel(Views.addDimension(ogImg, 0, 0), targetImg); - } else { - throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." - + "The image dimensions order should be 'yxc', first dimension height, second width and third channels."); - } - this.img = targetImg; - this.targetDims = targetImg.dimensionsAsLongArray(); - } /** * Set an empty consumer as {@link DebugTextPrinter} to avoid the SAMJ model instance @@ -899,134 +861,6 @@ public void printScript(final String script, final String designationOfTheScript debugPrinter.printText(script); debugPrinter.printText("END: =========== "+designationOfTheScript+" ==========="); } - - /** - * Get the maximum and minimum pixel values of an {@link IterableInterval} - * @param - * the ImgLib2 data types that the {@link IterableInterval} can have - * @param inImg - * the {@link IterableInterval} from which the max and min values are going to be found - * @param outMinMax - * double array where the max and min values of the {@link IterableInterval} will be written - */ - public static & NativeType> - void getMinMaxPixelValue(final IterableInterval inImg, final double[] outMinMax) { - double min = inImg.firstElement().getRealDouble(); - double max = min; - - for (T px : inImg) { - double val = px.getRealDouble(); - min = Math.min(min,val); - max = Math.max(max,val); - } - - if (outMinMax.length > 1) { - outMinMax[0] = min; - outMinMax[1] = max; - } - } - - /** - * Whether the values in the length 2 array are between 0 and 1 - * @param inMinMax - * the interval to be evaluated - * @return true if the values are between 0 and 1 and false otherwise - */ - public static boolean isNormalizedInterval(final double[] inMinMax) { - return (inMinMax[0] >= 0 && inMinMax[0] <= 1 - && inMinMax[1] >= 0 && inMinMax[1] <= 1); - } - - /** - * Normalize the {@link RandomAccessibleInterval} with the position 0 of the inMimMax array as the min - * and the position 1 as the max - * @param - * the ImgLib2 data types that the {@link RandomAccessibleInterval} can have - * @param inImg - * {@link RandomAccessibleInterval} to be normalized - * @param inMinMax - * the values to which the {@link RandomAccessibleInterval} will be normalized. Should be a double array of length - * 2 with the smaller value at position 0 - * @return the normalized {@link RandomAccessibleInterval} - */ - private static & NativeType> - RandomAccessibleInterval normalizedView(final RandomAccessibleInterval inImg, final double[] inMinMax) { - final double min = inMinMax[0]; - final double range = inMinMax[1] - min; - return Converters.convert(inImg, (i, o) -> o.setReal((i.getRealFloat() - min) / (range + 1e-9)), new FloatType()); - } - - /** - * Checks the input RAI if its min and max pixel values are between [0,1]. - * If they are not, the RAI will be subject to {@link Converters#convert(RandomAccessibleInterval, Converter, Type)} - * with here-created Converter that knows how to bring the pixel values into the interval [0,1]. - * - * @param - * the ImgLib2 data types that the {@link RandomAccessibleInterval} can have - * @param inImg - * RAI to be potentially normalized. - * @return The input image itself or a View of it with {@link FloatType} data type - */ - public & NativeType> - RandomAccessibleInterval normalizedView(final RandomAccessibleInterval inImg) { - final double[] minMax = new double[2]; - getMinMaxPixelValue(Views.iterable(inImg), minMax); - ///debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS _NOT_ NORMALIZED, returning Converted view"); - //return normalizedView(inImg, minMax); - if (isNormalizedInterval(minMax) && Util.getTypeFromInterval(inImg) instanceof FloatType) { - debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS NORMALIZED, returning directly itself"); - return Cast.unchecked(inImg); - } else if (isNormalizedInterval(minMax)) { - debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS NORMALIZED, returning directly itself"); - return Converters.convert(inImg, (i, o) -> o.setReal(i.getRealFloat()), new FloatType()); - } else { - debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS _NOT_ NORMALIZED, returning Converted view"); - return normalizedView(inImg, minMax); - } - } - - private static & NativeType> - RandomAccessibleInterval convertViewToRGB(final RandomAccessibleInterval inImg, final double[] inMinMax) { - final double min = inMinMax[0]; - final double range = inMinMax[1] - min; - return Converters.convert(inImg, (i, o) -> o.setReal(255 * (i.getRealDouble() - min) / range), new UnsignedByteType()); - } - - /** - * Checks the input RAI if its min and max pixel values are between [0,255] and if it is of {@link UnsignedByteType} type. - * If they are not, the RAI will be subject to {@link Converters#convert(RandomAccessibleInterval, Converter, Type)} - * with here-created Converter that knows how to bring the pixel values into the interval [0,255]. - * - * @param inImg - * RAI to be potentially converted to RGB. - * @return The input image itself or a View of it in {@link UnsignedByteType} data type - */ - public & NativeType> - RandomAccessibleInterval convertViewToRGB(final RandomAccessibleInterval inImg) { - if (Util.getTypeFromInterval(inImg) instanceof UnsignedByteType) { - debugPrinter.printText("IMAGE IS RGB, returning directly itself"); - return Cast.unchecked(inImg); - } - final double[] minMax = new double[2]; - debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS _NOT_ RGB, returning Converted view"); - getMinMaxPixelValue(Views.iterable(inImg), minMax); - return convertViewToRGB(inImg, minMax); - } - - protected static & NativeType> RandomAccessibleInterval - reescaleIfNeeded(RandomAccessibleInterval rai) { - if ((rai.dimensionsAsLongArray()[0] > rai.dimensionsAsLongArray()[1]) - && (rai.dimensionsAsLongArray()[0] > MAX_ENCODED_AREA_RS)) { - // TODO reescale - return rai; - } else if ((rai.dimensionsAsLongArray()[0] < rai.dimensionsAsLongArray()[1]) - && (rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE)) { - // TODO reescale - return rai; - } else { - return rai; - } - } /** * Calculate the coordinates of the encoded image with respect to the coordinates @@ -1080,20 +914,4 @@ protected void recalculatePolys(List polys, long[] encodeCoords) { pp.ypoints = Arrays.stream(pp.ypoints).map(y -> y + (int) encodeCoords[1]).toArray(); }); } - - /** - * MEthod used during development to test features - * @param args - * nothing - * @throws IOException nothing - * @throws RuntimeException nothing - * @throws InterruptedException nothing - */ - public static void main(String[] args) throws IOException, RuntimeException, InterruptedException { - RandomAccessibleInterval img = ArrayImgs.unsignedBytes(new long[] {50, 50, 3}); - img = Views.addDimension(img, 1, 2); - try (AbstractSamJ2 sam = initializeSam(SamEnvManager.create(), img)) { - sam.processBox(new int[] {0, 5, 10, 26}); - } - } } diff --git a/src/main/java/ai/nets/samj/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/EfficientViTSamJ.java index a213f38..3e61ed7 100644 --- a/src/main/java/ai/nets/samj/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/EfficientViTSamJ.java @@ -19,35 +19,24 @@ */ package ai.nets.samj; -import java.lang.AutoCloseable; -import java.util.ArrayList; import java.util.HashMap; -import java.util.Iterator; import java.util.List; -import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; -import java.awt.Polygon; -import java.awt.Rectangle; import java.io.File; import java.io.IOException; import io.bioimage.modelrunner.apposed.appose.Environment; -import io.bioimage.modelrunner.apposed.appose.Service; import io.bioimage.modelrunner.apposed.appose.Service.Task; import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; -import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; import net.imglib2.converter.RealTypeConverters; import net.imglib2.img.array.ArrayImgs; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.UnsignedByteType; -import net.imglib2.util.Cast; import net.imglib2.util.Intervals; -import net.imglib2.util.Util; import net.imglib2.view.IntervalView; import net.imglib2.view.Views; @@ -206,7 +195,7 @@ else if (task.status == TaskStatus.CRASHED) try{ sam = new EfficientViTSamJ(manager, modelType, debugPrinter, printPythonCode); sam.encodeCoords = new long[] {0, 0}; - sam.addImage(image); + sam.updateImage(image); } catch (IOException | RuntimeException | InterruptedException ex) { if (sam != null) sam.close(); throw ex; @@ -239,7 +228,7 @@ else if (task.status == TaskStatus.CRASHED) try{ sam = new EfficientViTSamJ(manager, modelType); sam.encodeCoords = new long[] {0, 0}; - sam.addImage(image); + sam.updateImage(image); } catch (IOException | RuntimeException | InterruptedException ex) { if (sam != null) sam.close(); throw ex; @@ -301,86 +290,28 @@ else if (task.status == TaskStatus.CRASHED) initializeSam(SamEnvManager manager, RandomAccessibleInterval image) throws IOException, RuntimeException, InterruptedException { return initializeSam(SamEnvManager.DEFAULT_EVITSAM, manager, image); } - - /** - * Encode an image (n-dimensional array) with an EfficientViTSAM model - * @param - * ImgLib2 data type of the image of interest - * @param rai - * image (n-dimensional array) that is going to be encoded as a {@link RandomAccessibleInterval} - * @throws IOException if any of the files to run a Python process is missing - * @throws RuntimeException if there is any error running the Python code - * @throws InterruptedException if the process is interrupted - */ - private & NativeType> - void addImage(RandomAccessibleInterval rai) - throws IOException, RuntimeException, InterruptedException { - if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS - || rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) { - this.targetDims = new long[] {0, 0, 0}; - this.img = rai; - return; - } - this.script = ""; - sendImgLib2AsNp(rai); - this.script += "" - + "task.update(str(im.shape))" + System.lineSeparator() - + "predictor.set_image(im)"; - try { - printScript(script, "Creation of initial embeddings"); - Task task = python.task(script); - task.waitFor(); - if (task.status == TaskStatus.CANCELED) - throw new RuntimeException(); - else if (task.status == TaskStatus.FAILED) - throw new RuntimeException(); - else if (task.status == TaskStatus.CRASHED) - throw new RuntimeException(); - this.shma.close(); - } catch (IOException | InterruptedException | RuntimeException e) { - try { - this.shma.close(); - } catch (IOException e1) { - throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); - } - throw e; - } + + @Override + protected & NativeType> void setImageOfInterest(RandomAccessibleInterval rai) { + checkImageIsFine(rai); + long[] dims = rai.dimensionsAsLongArray(); + this.img = Views.interval(rai, new long[] {0, 0, 0}, new long[] {dims[0] - 1, dims[1] - 1, 2}); + this.targetDims = img.dimensionsAsLongArray(); } - - /** - * Method used that runs EfficientViTSAM using a mask as the prompt. The mask should be a 2D single-channel - * image {@link RandomAccessibleInterval} of the same x and y sizes as the image of interest, the image - * where the model is finding the segmentations. - * Note that the quality of this prompting method is not good, it is still experimental as it barely works - * - * @param - * ImgLib2 datatype of the mask - * @param img - * mask used as the prompt - * @param returnAll - * whether to return all the polygons created by EfficientSAM of only the biggest - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientViTSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public & NativeType> - List processMask(RandomAccessibleInterval img, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - long[] dims = img.dimensionsAsLongArray(); - if (dims.length == 2 && dims[1] == this.shma.getOriginalShape()[1] && dims[0] == this.shma.getOriginalShape()[0]) { - img = Views.permute(img, 0, 1); - } else if (dims.length != 2 && dims[0] != this.shma.getOriginalShape()[1] && dims[1] != this.shma.getOriginalShape()[0]) { - throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width " - + this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]); - } - SharedMemoryArray maskShma = SharedMemoryArray.buildSHMA(img); - try { - return processMask(maskShma, returnAll); - } catch (IOException | RuntimeException | InterruptedException ex) { - maskShma.close(); - throw ex; - } + + @Override + protected void createEncodeImageScript() { + this.script = "" + + "task.update(str(im.shape))" + System.lineSeparator() + + "predictor.set_image(im)"; + } + + @Override + protected & NativeType> void createSHMArray(RandomAccessibleInterval imShared) { + RandomAccessibleInterval imageToBeSent = ImgLib2SAMUtils.reescaleIfNeeded(imShared); + long[] dims = imageToBeSent.dimensionsAsLongArray(); + shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], dims[2]}, new UnsignedByteType()); + adaptImageToModel(imageToBeSent, shma.getSharedRAI()); } @Override @@ -478,24 +409,32 @@ protected void processBoxWithSAM(boolean returnAll) { this.script = code; } - private static & NativeType> - SharedMemoryArray createEfficientSAMInputSHM(final RandomAccessibleInterval inImg) { + private & NativeType> void checkImageIsFine(RandomAccessibleInterval inImg) { long[] dims = inImg.dimensionsAsLongArray(); - if ((dims.length != 3 && dims.length != 2) || (dims.length == 3 && dims[2] != 3 && dims[2] != 1)) { - throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, float32, ...) 2D images." - + "The image dimensions order should be 'xyc', first dimension height, second width and third channels."); + if ((dims.length != 3 && dims.length != 2) || (dims.length == 3 && dims[2] != 3 && dims[2] != 1)){ + throw new IllegalArgumentException("Currently EfficientViTSAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." + + "The image dimensions order should be 'xyc', first dimension width, second height and third channels."); } - return SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], 3}, new UnsignedByteType()); + } + + /** + * + * @return the list of EfficientViTSAM models that are supported + */ + public static List getListOfSupportedEfficientViTSAM(){ + return MODELS_DICT.keySet().stream().collect(Collectors.toList()); } private & NativeType> - void adaptImageToModel(final RandomAccessibleInterval ogImg, RandomAccessibleInterval targetImg) { + void adaptImageToModel(RandomAccessibleInterval ogImg, RandomAccessibleInterval targetImg) { if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 3) { for (int i = 0; i < 3; i ++) - RealTypeConverters.copyFromTo( convertViewToRGB(Views.hyperSlice(ogImg, 2, i)), Views.hyperSlice(targetImg, 2, i) ); + RealTypeConverters.copyFromTo( ImgLib2SAMUtils.convertViewToRGB(Views.hyperSlice(ogImg, 2, i), this.debugPrinter), + Views.hyperSlice(targetImg, 2, i) ); } else if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 1) { debugPrinter.printText("CONVERTED 1 CHANNEL IMAGE INTO 3 TO BE FEEDED TO SAMJ"); - IntervalView resIm = Views.interval( Views.expandMirrorDouble(convertViewToRGB(ogImg), new long[] {0, 0, 2}), + IntervalView resIm = + Views.interval( Views.expandMirrorDouble(ImgLib2SAMUtils.convertViewToRGB(ogImg, this.debugPrinter), new long[] {0, 0, 2}), Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0], ogImg.dimensionsAsLongArray()[1], 2}) ); RealTypeConverters.copyFromTo( resIm, targetImg ); } else if (ogImg.numDimensions() == 2) { @@ -504,8 +443,6 @@ void adaptImageToModel(final RandomAccessibleInterval ogImg, RandomAccessible throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." + "The image dimensions order should be 'yxc', first dimension height, second width and third channels."); } - this.img = targetImg; - this.targetDims = targetImg.dimensionsAsLongArray(); } /** @@ -523,24 +460,4 @@ public static void main(String[] args) throws IOException, RuntimeException, Int sam.processBox(new int[] {0, 5, 10, 26}); } } - - /** - * - * @return the list of EfficientViTSAM models that are supported - */ - public static List getListOfSupportedEfficientViTSAM(){ - return MODELS_DICT.keySet().stream().collect(Collectors.toList()); - } - - @Override - protected & NativeType> void adaptImage(RandomAccessibleInterval rai) { - // TODO Auto-generated method stub - - } - - @Override - protected void createEncodeImageScript() { - // TODO Auto-generated method stub - - } } diff --git a/src/main/java/ai/nets/samj/ImgLib2SAMUtils.java b/src/main/java/ai/nets/samj/ImgLib2SAMUtils.java new file mode 100644 index 0000000..b631e3b --- /dev/null +++ b/src/main/java/ai/nets/samj/ImgLib2SAMUtils.java @@ -0,0 +1,146 @@ +package ai.nets.samj; + +import ai.nets.samj.AbstractSamJ2.DebugTextPrinter; +import net.imglib2.IterableInterval; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.converter.Converter; +import net.imglib2.converter.Converters; +import net.imglib2.type.NativeType; +import net.imglib2.type.Type; +import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; +import net.imglib2.util.Util; +import net.imglib2.view.Views; + +public class ImgLib2SAMUtils { + + /** + * Get the maximum and minimum pixel values of an {@link IterableInterval} + * @param + * the ImgLib2 data types that the {@link IterableInterval} can have + * @param inImg + * the {@link IterableInterval} from which the max and min values are going to be found + * @param outMinMax + * double array where the max and min values of the {@link IterableInterval} will be written + */ + public static & NativeType> + void getMinMaxPixelValue(final IterableInterval inImg, final double[] outMinMax) { + double min = inImg.firstElement().getRealDouble(); + double max = min; + + for (T px : inImg) { + double val = px.getRealDouble(); + min = Math.min(min,val); + max = Math.max(max,val); + } + + if (outMinMax.length > 1) { + outMinMax[0] = min; + outMinMax[1] = max; + } + } + + /** + * Whether the values in the length 2 array are between 0 and 1 + * @param inMinMax + * the interval to be evaluated + * @return true if the values are between 0 and 1 and false otherwise + */ + public static boolean isNormalizedInterval(final double[] inMinMax) { + return (inMinMax[0] >= 0 && inMinMax[0] <= 1 + && inMinMax[1] >= 0 && inMinMax[1] <= 1); + } + + /** + * Normalize the {@link RandomAccessibleInterval} with the position 0 of the inMimMax array as the min + * and the position 1 as the max + * @param + * the ImgLib2 data types that the {@link RandomAccessibleInterval} can have + * @param inImg + * {@link RandomAccessibleInterval} to be normalized + * @param inMinMax + * the values to which the {@link RandomAccessibleInterval} will be normalized. Should be a double array of length + * 2 with the smaller value at position 0 + * @return the normalized {@link RandomAccessibleInterval} + */ + private static & NativeType> + RandomAccessibleInterval normalizedView(final RandomAccessibleInterval inImg, final double[] inMinMax) { + final double min = inMinMax[0]; + final double range = inMinMax[1] - min; + return Converters.convert(inImg, (i, o) -> o.setReal((i.getRealFloat() - min) / (range + 1e-9)), new FloatType()); + } + + /** + * Checks the input RAI if its min and max pixel values are between [0,1]. + * If they are not, the RAI will be subject to {@link Converters#convert(RandomAccessibleInterval, Converter, Type)} + * with here-created Converter that knows how to bring the pixel values into the interval [0,1]. + * + * @param + * the ImgLib2 data types that the {@link RandomAccessibleInterval} can have + * @param inImg + * RAI to be potentially normalized. + * @return The input image itself or a View of it with {@link FloatType} data type + */ + public static & NativeType> + RandomAccessibleInterval normalizedView(final RandomAccessibleInterval inImg, DebugTextPrinter debugPrinter) { + final double[] minMax = new double[2]; + getMinMaxPixelValue(Views.iterable(inImg), minMax); + ///debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS _NOT_ NORMALIZED, returning Converted view"); + //return normalizedView(inImg, minMax); + if (isNormalizedInterval(minMax) && Util.getTypeFromInterval(inImg) instanceof FloatType) { + debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS NORMALIZED, returning directly itself"); + return Cast.unchecked(inImg); + } else if (isNormalizedInterval(minMax)) { + debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS NORMALIZED, returning directly itself"); + return Converters.convert(inImg, (i, o) -> o.setReal(i.getRealFloat()), new FloatType()); + } else { + debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS _NOT_ NORMALIZED, returning Converted view"); + return normalizedView(inImg, minMax); + } + } + + private static & NativeType> + RandomAccessibleInterval convertViewToRGB(final RandomAccessibleInterval inImg, final double[] inMinMax) { + final double min = inMinMax[0]; + final double range = inMinMax[1] - min; + return Converters.convert(inImg, (i, o) -> o.setReal(255 * (i.getRealDouble() - min) / range), new UnsignedByteType()); + } + + /** + * Checks the input RAI if its min and max pixel values are between [0,255] and if it is of {@link UnsignedByteType} type. + * If they are not, the RAI will be subject to {@link Converters#convert(RandomAccessibleInterval, Converter, Type)} + * with here-created Converter that knows how to bring the pixel values into the interval [0,255]. + * + * @param inImg + * RAI to be potentially converted to RGB. + * @return The input image itself or a View of it in {@link UnsignedByteType} data type + */ + public static & NativeType> + RandomAccessibleInterval convertViewToRGB(final RandomAccessibleInterval inImg, DebugTextPrinter debugPrinter) { + if (Util.getTypeFromInterval(inImg) instanceof UnsignedByteType) { + debugPrinter.printText("IMAGE IS RGB, returning directly itself"); + return Cast.unchecked(inImg); + } + final double[] minMax = new double[2]; + debugPrinter.printText("MIN VALUE="+minMax[0]+", MAX VALUE="+minMax[1]+", IMAGE IS _NOT_ RGB, returning Converted view"); + getMinMaxPixelValue(Views.iterable(inImg), minMax); + return convertViewToRGB(inImg, minMax); + } + + protected static & NativeType> RandomAccessibleInterval + reescaleIfNeeded(RandomAccessibleInterval rai) { + if ((rai.dimensionsAsLongArray()[0] > rai.dimensionsAsLongArray()[1]) + && (rai.dimensionsAsLongArray()[0] > AbstractSamJ2.MAX_ENCODED_AREA_RS)) { + // TODO reescale + return rai; + } else if ((rai.dimensionsAsLongArray()[0] < rai.dimensionsAsLongArray()[1]) + && (rai.dimensionsAsLongArray()[1] > AbstractSamJ2.MAX_ENCODED_SIDE)) { + // TODO reescale + return rai; + } else { + return rai; + } + } +}