Skip to content

Commit

Permalink
Finished linear interpolation with auto contrast
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorTatarnikov committed Jan 29, 2024
1 parent 46cf08a commit c3fa6bb
Show file tree
Hide file tree
Showing 2 changed files with 349 additions and 59 deletions.
305 changes: 296 additions & 9 deletions mesospim_stitcher/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fuse_to_zarr(
):
print("Fusing to zarr")

chunk_shape: Tuple[int, ...] = (128, 256, 256)
chunk_shape: Tuple[int, ...] = (64, 128, 128)
tiles = []

for child in group:
Expand Down Expand Up @@ -242,6 +242,12 @@ def fuse_to_zarr(
{"name": "x", "type": "space", "unit": "micrometer"},
]

overlaps = calculate_overlaps(translations)
adjust_contrast(intensity_scale_factors, tiles)
interpolate_overlaps(
overlaps, tiles, slice_attributes, tile_names, translations
)

store = zarr.NestedDirectoryStore(str(output_path))
root = zarr.group(store=store)
compressor = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)
Expand All @@ -261,10 +267,10 @@ def fuse_to_zarr(
x_s, x_e, y_s, y_e, z_s, z_e = translations[i]
curr_tile = tiles[i]
channel_idx = int(slice_attributes[tile_names[i]]["channel"])
if intensity_scale_factors[i] != 1.0:
curr_tile = da.multiply(
curr_tile, intensity_scale_factors[i], dtype=np.float16
).astype(np.int16)
# if intensity_scale_factors[i] != 1.0:
# curr_tile = da.multiply(
# curr_tile, intensity_scale_factors[i], dtype=np.float16
# ).astype(np.int16)

if num_channels > 1:
fused_image_store[
Expand All @@ -287,10 +293,10 @@ def fuse_to_zarr(
downsampled_image = downscale_nearest(
prev_resolution, (1, 1, 2, 2)
)
chunk_shape = (1, 64, 128, 128)
chunk_shape = (1, 32, 64, 64)
else:
downsampled_image = downscale_nearest(prev_resolution, (1, 2, 2))
chunk_shape = (64, 128, 128)
chunk_shape = (32, 64, 64)

downsampled_shape = downsampled_image.shape
downsampled_store = root.require_dataset(
Expand Down Expand Up @@ -527,9 +533,290 @@ def write_ome_zarr(
}


def calculate_overlaps(translations):
overlaps = {}

for i in range(len(translations) - 1):
curr_translation = translations[i]
for j in range(i + 1, len(translations)):
next_translation = translations[j]

if (curr_translation[1] > next_translation[0]) and (
curr_translation[3] > next_translation[2]
):
x_overlap_s = max(curr_translation[0], next_translation[0])
x_overlap_e = min(curr_translation[1], next_translation[1])
y_overlap_s = max(curr_translation[2], next_translation[2])
y_overlap_e = min(curr_translation[3], next_translation[3])
z_overlap_s = max(curr_translation[4], next_translation[4])
z_overlap_e = min(curr_translation[5], next_translation[5])

overlaps[(i, j)] = (
x_overlap_s,
x_overlap_e,
y_overlap_s,
y_overlap_e,
z_overlap_s,
z_overlap_e,
)

return overlaps


def calculate_scale_factors(
overlaps, images, slice_attributes, tile_names, translations, percentile
):
# nodes_visited = set()
num_tiles = len(images)
scale_factors = np.ones((num_tiles, num_tiles))
z_size, y_size, x_size = images[0].shape

for i in range(num_tiles):
channel_idx = int(slice_attributes[tile_names[i]]["channel"])
for j in range(i + 1, num_tiles):
if (i, j) in overlaps:
other_channel_idx = int(
slice_attributes[tile_names[j]]["channel"]
)
if channel_idx == other_channel_idx:
(
x_overlap_s,
x_overlap_e,
y_overlap_s,
y_overlap_e,
z_overlap_s,
z_overlap_e,
) = overlaps[(i, j)]
x_overlap = x_overlap_e - x_overlap_s
y_overlap = y_overlap_e - y_overlap_s
z_overlap = z_overlap_e - z_overlap_s

if translations[i][0] < translations[j][0]:
i_x_s = x_size - x_overlap
i_x_e = x_size
j_x_s = 0
j_x_e = x_overlap
else:
i_x_s = 0
i_x_e = x_overlap
j_x_s = x_size - x_overlap
j_x_e = x_size

if translations[i][2] < translations[j][2]:
i_y_s = y_size - y_overlap
i_y_e = y_size
j_y_s = 0
j_y_e = y_overlap
else:
i_y_s = 0
i_y_e = y_overlap
j_y_s = y_size - y_overlap
j_y_e = y_size

if translations[i][4] < translations[j][4]:
i_z_s = z_size - z_overlap
i_z_e = z_size
j_z_s = 0
j_z_e = z_overlap
else:
i_z_s = 0
i_z_e = z_overlap
j_z_s = z_size - z_overlap
j_z_e = z_size

i_image = images[i][i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e]
j_image = images[j][j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e]

median_i = np.percentile(i_image.ravel(), percentile)
median_j = np.percentile(j_image.ravel(), percentile)

curr_scale_factor = (median_i / median_j).compute()
scale_factors[i][j] = curr_scale_factor[0]

del i_image
del j_image
del median_i
del median_j

images[j] = np.multiply(
images[j], curr_scale_factor, dtype=np.float16
).astype(np.int16)

return scale_factors, images


def adjust_contrast(intensity_scale_factors, images):
num_tiles = len(images)

for i in range(num_tiles):
if intensity_scale_factors[i] != 1.0:
images[i] = da.multiply(
images[i], intensity_scale_factors[i], dtype=np.float16
).astype(np.int16)


def interpolate_overlaps(
overlaps, images, slice_attributes, tile_names, translations
):
num_tiles = len(images)
z_size, y_size, x_size = images[0].shape

for i in range(num_tiles):
channel_idx = int(slice_attributes[tile_names[i]]["channel"])
for j in range(i + 1, num_tiles):
if (i, j) in overlaps:
other_channel_idx = int(
slice_attributes[tile_names[j]]["channel"]
)
if channel_idx == other_channel_idx:
(
x_overlap_s,
x_overlap_e,
y_overlap_s,
y_overlap_e,
z_overlap_s,
z_overlap_e,
) = overlaps[(i, j)]
x_overlap = x_overlap_e - x_overlap_s
y_overlap = y_overlap_e - y_overlap_s
z_overlap = z_overlap_e - z_overlap_s

if translations[i][0] < translations[j][0]:
i_x_s = x_size - x_overlap
i_x_e = x_size
j_x_s = 0
j_x_e = x_overlap
else:
i_x_s = 0
i_x_e = x_overlap
j_x_s = x_size - x_overlap
j_x_e = x_size

if translations[i][2] < translations[j][2]:
i_y_s = y_size - y_overlap
i_y_e = y_size
j_y_s = 0
j_y_e = y_overlap
else:
i_y_s = 0
i_y_e = y_overlap
j_y_s = y_size - y_overlap
j_y_e = y_size

if translations[i][4] < translations[j][4]:
i_z_s = z_size - z_overlap
i_z_e = z_size
j_z_s = 0
j_z_e = z_overlap
else:
i_z_s = 0
i_z_e = z_overlap
j_z_s = z_size - z_overlap
j_z_e = z_size

if x_overlap / x_size < 0.2 and y_overlap / y_size < 0.2:
# Skip the small diagonal overlaps
# continue
x_lin = np.linspace(1, 0, x_overlap)
y_lin = np.linspace(1, 0, y_overlap)

# 1 at 0, 0 linearly decreasing to 0 at 1,
# 1 along the diagonal
yx_grid = np.outer(y_lin, x_lin)

if translations[i][0] < translations[j][0]:
decreasing_image = images[i][
i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e
]
increasing_image = images[j][
j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e
]

if translations[i][2] > translations[j][2]:
yx_grid = np.flip(yx_grid, 0)
else:
decreasing_image = images[j][
j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e
]
increasing_image = images[i][
i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e
]

# Check the direction of the diagonal.
# Flip the grid if the diagonal is going from top right
# to bottom left. Left to right is the default,
# this is when both ys and xs of i are either
# smaller or larger than j.
# If the two comparisons don't match (XOR) flip grid.
if (translations[i][0] < translations[j][0]) != (
translations[i][2] < translations[j][2]
):
if translations[i][2] < translations[j][2]:
yx_grid = np.flip(yx_grid, 0)

elif x_overlap / x_size < 0.2:
x_lin = np.linspace(1, 0, x_overlap)

# 1 in the first column,
# linearly decreasing to 0 in the last column
yx_grid = np.tile(x_lin, (y_overlap, 1))

if translations[i][0] < translations[j][0]:
decreasing_image = images[i][
i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e
]
increasing_image = images[j][
j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e
]
else:
decreasing_image = images[j][
j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e
]
increasing_image = images[i][
i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e
]
else:
y_lin = np.linspace(1, 0, y_overlap)

# 1 in the first row,
# linearly decreasing to 0 in the last row
yx_grid = np.tile(y_lin, (x_overlap, 1)).T

if translations[i][2] < translations[j][2]:
decreasing_image = images[i][
i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e
]
increasing_image = images[j][
j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e
]
else:
decreasing_image = images[j][
j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e
]
increasing_image = images[i][
i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e
]

interp = (
np.multiply(
decreasing_image.compute(),
yx_grid,
dtype=np.float16,
)
+ np.multiply(
increasing_image.compute(),
1 - yx_grid,
dtype=np.float16,
)
).astype(np.int16)

images[i][i_z_s:i_z_e, i_y_s:i_y_e, i_x_s:i_x_e] = interp
images[j][j_z_s:j_z_e, j_y_s:j_y_e, j_x_s:j_x_e] = interp

print(f"Done interpolating tile {i} and {j}")


def calculate_image_stats(image: da) -> Tuple[float, float, float, float]:
# min_intensity = image.min().compute()
# max_intensity = image.max().compute()
num_pixels = image.shape[0] * image.shape[1] * image.shape[2]
raveled_image = image.ravel()
top_min = raveled_image.topk(-int(num_pixels * 0.00001)).compute()
Expand Down
Loading

0 comments on commit c3fa6bb

Please sign in to comment.