Skip to content

Commit

Permalink
correct bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 28, 2024
1 parent f9f6243 commit 1840577
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
14 changes: 4 additions & 10 deletions src/main/java/ai/nets/samj/annotation/Mask.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package ai.nets.samj.annotation;

import java.awt.Polygon;
import java.awt.Rectangle;
import java.util.Arrays;
import java.util.List;

Expand All @@ -41,16 +40,13 @@ public class Mask {

private final long[] rleEncoding;

private final Rectangle crop;

private Mask(Polygon contour, long[] rleEncoding, Rectangle crop) {
private Mask(Polygon contour, long[] rleEncoding) {
this.contour = contour;
this.rleEncoding = rleEncoding;
this.crop = crop;
}

public static Mask build(Polygon contour, long[] rleEncoding, Rectangle crop) {
return new Mask(contour, rleEncoding, crop);
public static Mask build(Polygon contour, long[] rleEncoding) {
return new Mask(contour, rleEncoding);
}

public Polygon getContour() {
Expand All @@ -77,9 +73,7 @@ public static RandomAccessibleInterval<UnsignedByteType> getMask(long width, lon
for (Mask mask : masks) {
long[] rle = mask.getRLEMask();
for (int i = 0; i < rle.length; i += 2) {
int cropStartx = mask.crop.x;
int cropStarty = mask.crop.y;
int start = (int) (width * (cropStarty + i / 2) + cropStartx + mask.getRLEMask()[i]);
int start = (int) mask.getRLEMask()[i];
int len = (int) mask.getRLEMask()[i+ 1];
Arrays.fill(arr, start, start + len, (byte) 1);
}
Expand Down
16 changes: 10 additions & 6 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ else if (task.outputs.get("contours_x") == null)
throw new RuntimeException();
else if (task.outputs.get("contours_y") == null)
throw new RuntimeException();
else if (task.outputs.get("rles") == null)
else if (task.outputs.get("rle") == null)
throw new RuntimeException();
results = task.outputs;
} catch (IOException | InterruptedException | RuntimeException e) {
Expand All @@ -393,7 +393,7 @@ else if (task.outputs.get("rles") == null)
int[] xArr = contours_x.next().stream().mapToInt(Number::intValue).toArray();
int[] yArr = contours_y.next().stream().mapToInt(Number::intValue).toArray();
long[] rle = rles.next().stream().mapToLong(Number::longValue).toArray();
masks.add(Mask.build(new Polygon(xArr, yArr, xArr.length), rle, cropRect));
masks.add(Mask.build(new Polygon(xArr, yArr, xArr.length), rle));
}
return masks;
}
Expand Down Expand Up @@ -1034,15 +1034,19 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize)
* to detect small objects compared to the size of the whole image, SAMJ might encode crops of
* the total image, thus the coordinates of the polygons obtained need to be shifted in order
* to match the original image.
* @param polys
* polys obtained by SAMJ on the encoded crop
* @param masks
* masks obtained by SAMJ on the encoded crop
* @param encodeCoords
* position of the crop in the total image
*/
protected void recalculatePolys(List<Mask> polys, long[] encodeCoords) {
polys.stream().forEach(pp -> {
protected void recalculatePolys(List<Mask> masks, long[] encodeCoords) {
masks.stream().forEach(pp -> {
pp.getContour().xpoints = Arrays.stream(pp.getContour().xpoints).map(x -> x + (int) encodeCoords[0]).toArray();
pp.getContour().ypoints = Arrays.stream(pp.getContour().ypoints).map(y -> y + (int) encodeCoords[1]).toArray();
for (int i = 0; i < pp.getRLEMask().length; i += 2) {
pp.getRLEMask()[i] = encodeCoords[0] + pp.getRLEMask()[i] % this.targetDims[0]
+ (((int) (pp.getRLEMask()[i] / this.targetDims[0])) + encodeCoords[1]) * this.targetDims[0];
}
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/ai/nets/samj/models/PythonMethods.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public class PythonMethods {
+ " rle = encode_rle(obj.image)" + System.lineSeparator()
+ " bbox_w = obj.bbox[3] - obj.bbox[1]" + System.lineSeparator()
+ " for i in range(0, len(rle), 2):" + System.lineSeparator()
+ " rle[i] = sam_result.shape[1] * (obj.bbox[0] - 1 + rle[i] // bbox_w) + obj.bbox[1] + rle[i] % bbox_w" + System.lineSeparator()
+ " rle[i] = sam_result.shape[1] * (obj.bbox[0] + rle[i] // bbox_w) + obj.bbox[1] + rle[i] % bbox_w" + System.lineSeparator()
+ " rles.append(rle)" + System.lineSeparator()
+ " x_contours.append(x_coords)" + System.lineSeparator()
+ " y_contours.append(y_coords)" + System.lineSeparator()
Expand Down

0 comments on commit 1840577

Please sign in to comment.