jordancaraballo opened this issue 2 years ago
Hi Jordan!
As of v0.5.7:
The current logits/argmax functionality is not flexible and can definitely be improved. Moreover, your example highlights a limitation that I overlooked completely. I also use the library to tile images, feed the tiles to a semantic segmentation network and merge them back into the full result, but in my case those images don't have a channel dimension, as it's always just one value per pixel. So I never used Merger's logits/argmax functionality and Tiler's channel_dimension at the same time...
Merger's `add` expects data with the same shape as the `tile_shape` of the original Tiler. Similarly, `add_batch` expects data of shape `[batch, *tile_shape]`. In your example, the `batch` variable is expected to be of shape `(batch, 256, 256, 4)`.
If you specify `logits` for Merger, the expected shape changes to `[logits, *tile_shape]` (or `[batch, logits, *tile_shape]` for `add_batch`). If we specify logits in your example (e.g. `merger = Merger(tiler, logits=6)`), the expected data for `add_batch` would become `(batch, 6, 256, 256, 4)`, which is also not what you want.
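To make the two shape rules concrete, here is a tiny pure-Python sketch. `expected_add_batch_shape` is a hypothetical helper, not part of tiler; it only restates the rules described above.

```python
def expected_add_batch_shape(batch, tile_shape, logits=0):
    """Hypothetical helper restating Merger's expected add_batch shapes.

    Without logits: [batch, *tile_shape]
    With logits:    [batch, logits, *tile_shape]
    """
    if logits:
        return (batch, logits, *tile_shape)
    return (batch, *tile_shape)


tile_shape = (256, 256, 4)  # channels-last tiles, as in the example
print(expected_add_batch_shape(512, tile_shape))            # (512, 256, 256, 4)
print(expected_add_batch_shape(512, tile_shape, logits=6))  # (512, 6, 256, 256, 4)
```

The second shape shows the problem: the class axis is inserted in front of the whole tile shape, including its channel axis.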
I will try to find the time to implement this soon, and sorry for not supporting your use case yet!
Hi,
Thanks for your response! While it might not be ideal, I was able to work around the channel_dimension constraint by creating a second Tiler object with the N channels that the network is supposed to output. The following is an example of the implementation. If you are okay with this, I can create a pull request with a similar example so that other users I point to this library can leverage it.
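Binary segmentation problem where the output is N x 256 x 256 x 1:

```python
import numpy as np
from tiler import Tiler, Merger

# `image` is a channels-last array of shape (H, W, 4) and `model` is a
# trained Keras model; both are defined elsewhere.
mode = 'constant'
batch_size = 512

# Tiler for the 4-channel input image
tiler_image = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 4),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)
# Second Tiler with the same geometry but 1 channel, matching the model output
tiler_mask = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 1),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)

# Extra padding for the overlap-tile window; both tilers get the same new shape
new_shape, padding = tiler_image.calculate_padding()
tiler_image.recalculate(data_shape=new_shape)
tiler_mask.recalculate(data_shape=new_shape)
padded_image = np.pad(image, padding, mode=mode, constant_values=1200)

merger = Merger(tiler=tiler_mask, window="overlap-tile")
for batch_id, batch in tiler_image(padded_image, batch_size=batch_size):
    batch = model.predict(batch)  # shape (batch, 256, 256, 1)
    merger.add_batch(batch_id, batch_size, batch)

# Merge as float so probabilities are not truncated by an integer image dtype
prediction = merger.merge(extra_padding=padding, dtype=np.float32)
prediction = np.squeeze(np.where(prediction > 0.5, 1, 0).astype(np.int16))
print(prediction.shape, prediction.min(), prediction.max())
```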
The only challenge I am still trying to work around is the presence of artifacts at the borders of non-uniform images (e.g. an image of size 90538 x 9148 x 4, with a tile size of 256 x 256 and a batch size of 512). Is this something you have worked around with this library? I can open a new issue for this topic as well. An example is illustrated below: the vertical lines at the left border of the image are not expected.
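One way to see why borders can behave differently from the interior is to count how many overlapping tiles cover each pixel. The sketch below is a simplified 1D model, assuming tiles start at 0, 128, 256, ... (256-pixel tiles with 50% overlap) and ignoring any padding; `coverage_1d` is a hypothetical helper, not tiler's actual tiling logic.

```python
import numpy as np


def coverage_1d(length, tile, step):
    """Count how many tiles cover each position along one axis.

    Tiles start at 0, step, 2*step, ...; the last tiles may run past the
    end of the axis, and positions beyond `length` are simply ignored.
    """
    counts = np.zeros(length, dtype=int)
    for start in range(0, length, step):
        counts[start:start + tile] += 1
    return counts


# 256-pixel tiles with 50% overlap -> step of 128 pixels, short image axis
counts = coverage_1d(9148, tile=256, step=128)
print(counts[:3], counts[128:131], counts[-3:])  # [1 1 1] [2 2 2] [2 2 2]
```

In this simplified model, the first 128 pixels are predicted by a single tile, using only its edge region, while interior pixels are averaged over two tiles. This is the kind of edge effect that extra padding before tiling is meant to compensate for.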
Nice workaround! Hopefully in the near future it will not be needed anymore!
Not sure how helpful these suggestions are, but:
- try the `hamming` window: it should apply more weight to the center of the tile, instead of weighting all tile pixels equally

I'm curious to hear if you manage to fix this :)
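In tiler, this suggestion amounts to passing `window="hamming"` to `Merger`. The sketch below uses plain NumPy to show the general idea of window-weighted merging, not tiler's actual implementation. The 1D tiles are hypothetical: each one is "correct" (1.0) in the middle and "wrong" (0.0) in its outer 16 pixels, mimicking edge errors from a segmentation network.

```python
import numpy as np


def merge_1d(tiles, starts, length, window):
    """Weighted overlap-average of 1D tile predictions.

    Each tile is multiplied by `window` and accumulated into the output,
    then the output is divided by the accumulated window weights.
    """
    acc = np.zeros(length)
    weights = np.zeros(length)
    tile = len(window)
    for data, start in zip(tiles, starts):
        end = min(start + tile, length)
        n = end - start
        acc[start:end] += data[:n] * window[:n]
        weights[start:end] += window[:n]
    return acc / np.maximum(weights, 1e-12)


def noisy_tile(tile, border=16):
    """Hypothetical tile prediction: 1.0 in the middle, 0.0 near the tile edges."""
    pred = np.ones(tile)
    pred[:border] = 0.0
    pred[-border:] = 0.0
    return pred


tile, step, length = 256, 128, 1024
starts = list(range(0, length - tile + 1, step))
hamming = np.hamming(tile)  # more weight at the tile center
boxcar = np.ones(tile)      # every tile pixel weighted equally

tiles = [noisy_tile(tile) for _ in starts]
merged_boxcar = merge_1d(tiles, starts, length, boxcar)
merged_hamming = merge_1d(tiles, starts, length, hamming)

# Position 256 is the edge of one tile and the center of another
print(merged_boxcar[256], merged_hamming[256])  # 0.5 vs. roughly 0.93
```

With equal weights, the wrong edge prediction of one tile pulls the result halfway down; the Hamming window mostly trusts the tile whose center covers that pixel.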
Is your feature request related to a problem? Please describe.
I have been trying to use this library for inference with TensorFlow binary and multiclass segmentation models. I am able to use the Tiler object to perform the predictions, but I have not been able to figure out how to use the Merger for the following cases:
- `data_shape = (5000, 3000, 4)`
- `tile_shape = (256, 256, 4)`
- `channel_dimension = 0`
The output of the model is a batch of shape either (N x 256 x 256 x 1) or (N x 256 x 256 x 6), where 6 is the number of classes.
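If the Merger's Tiler is built on the 2D spatial shape only (an assumption: no channel dimension, `tile_shape=(256, 256)`), then a Merger with `logits=6` should expect `add_batch` data laid out as `[batch, logits, *tile_shape]`, i.e. (N, 6, 256, 256). A channels-last model output can be rearranged into that layout with `np.moveaxis`; the random batch below is a stand-in for real predictions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tiles, n_classes = 2, 6  # small batch, just for illustration

# Stand-in for a channels-last model output: (N, 256, 256, 6)
batch = rng.random((n_tiles, 256, 256, n_classes), dtype=np.float32)

# Move the class axis right after the batch axis: (N, 6, 256, 256)
batch_logits_first = np.moveaxis(batch, -1, 1)
print(batch_logits_first.shape)  # (2, 6, 256, 256)
```

`np.moveaxis` returns a view, so the rearrangement itself does not copy the data.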
Describe the solution you'd like
It would be great to have additional examples for similar use cases with TensorFlow or PyTorch predictions.
Here is an example of what I have been trying:
I am probably missing something, but it would be nice to have this documented. Also, the argmax option seems to be hardcoded for channels-first images, which adds extra computation when using channels-last images. Any help would be appreciated.
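Until the argmax axis is configurable, one option is to merge the class scores without argmax and take the argmax yourself over the last axis. A NumPy sketch with made-up merged scores (the array is random, standing in for a merged channels-last result):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for merged channels-last class scores: (H, W, classes)
scores_last = rng.random((500, 300, 6))

# argmax over the last axis gives the label map directly, no transpose needed
labels_last = np.argmax(scores_last, axis=-1)

# Same result as moving the classes to the front and taking argmax over axis 0
labels_first = np.argmax(np.moveaxis(scores_last, -1, 0), axis=0)
print(labels_last.shape, np.array_equal(labels_last, labels_first))  # (500, 300) True
```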