I have pushed a fairly large rework of Merger. It now has three new or updated keywords: ignore_channels: bool = False, logits_n: Optional[int] = None, logits_dim: int = 0.
Here's an example of how I imagine they can be used together. It's significantly more flexible, but maybe the API is a bit too complex now.
@jordancaraballo, please take a look. What do you think? Am I missing anything in your opinion? Otherwise, in the next commits I will fix the tests and make sure I didn't break anything else.
import numpy as np
from tiler import Tiler, Merger
# Let's say you have an image of size 5000x3000 pixels and 4 channels in the last dimension
image_shape = (5000, 3000, 4)
image_channel_dimension = -1
# and you want to split it into tiles of 256x256 pixels with 4 channels in the last dimension
tile_shape = (256, 256, 4)
tile_overlap = 0.5
# to feed into a segmentation network with 10 output classes (in the last dimension) and batches of 128 tiles
# (so the network output has shape of (128, 256, 256, 10))
output_classes = 10
output_classes_dim = -1
batch_size = 128
image = np.random.rand(*image_shape)
tiler = Tiler(
    data_shape=image_shape,
    tile_shape=tile_shape,
    overlap=tile_overlap,  # previously defined but never passed to Tiler
    channel_dimension=image_channel_dimension,
)
merger = Merger(
    tiler,
    ignore_channels=True,  # "turns off" the channel dimension inherited from Tiler
    logits_n=output_classes,  # how many logits/segmentation classes there will be
    logits_dim=output_classes_dim,  # and in which dimension they are
)
print("Processing batches...")
for batch_id, batch in tiler(image, batch_size=batch_size):
    print(f"\tBatch: #{batch_id}, with data of shape {batch.shape}")
    # simulating network output of shape (128, 256, 256, 10)
    output = np.random.rand(batch_size, *tile_shape[:-1], output_classes)
    print(f"\tWe simulate NN output with shape of {output.shape} and add it to Merger")
    # adding output into Merger
    merger.add_batch(batch_id, batch_size, output)
print("Processing finished.")
raw_merge_result = merger.merge(argmax=None, unpad=False)
print(f"Shape of the raw merge result: {raw_merge_result.shape}") # (5120, 3072, 10)
unpad_merge_result = merger.merge(argmax=None, unpad=True)
print(f"Shape of the unpad merge result: {unpad_merge_result.shape}") # (5000, 3000, 10)
argmaxed_merge_result = merger.merge(argmax=output_classes_dim, unpad=True)
print(f"Shape of the argmaxed merge result: {argmaxed_merge_result.shape}") # (5000, 3000)
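For anyone checking the shapes in the comments above, here is a rough sketch of the arithmetic. It assumes the padded size is the smallest size that fits whole tiles at a fixed step; `padded_length` is a hypothetical helper written for illustration, not part of tiler's API, and the real implementation may differ in details. Under that assumption, overlaps of 0 and 0.5 both give (5120, 3072) for this image.

```python
import math

import numpy as np


def padded_length(data_len: int, tile_len: int, overlap: float) -> int:
    """Smallest length >= data_len covered by whole tiles at a fixed step.

    Hypothetical helper for illustration only; not part of tiler.
    """
    step = tile_len - int(tile_len * overlap)
    n_tiles = max(1, math.ceil((data_len - tile_len) / step) + 1)
    return (n_tiles - 1) * step + tile_len


# Spatial dimensions from the example: 5000x3000 image, 256x256 tiles
padded_half_overlap = tuple(padded_length(n, 256, 0.5) for n in (5000, 3000))
padded_no_overlap = tuple(padded_length(n, 256, 0.0) for n in (5000, 3000))
print(padded_half_overlap)  # (5120, 3072)
print(padded_no_overlap)  # (5120, 3072)

# Taking argmax over the logits dimension removes that dimension,
# which is why the argmaxed result above has no trailing 10.
logits = np.random.rand(8, 6, 10)
labels = np.argmax(logits, axis=-1)
print(labels.shape)  # (8, 6)
```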
Resolves #20