the-lay / tiler

N-dimensional NumPy array tiling and merging with overlapping, padding and tapering
https://the-lay.github.io/tiler/
MIT License

[IDEA] Figuring out how to use this library with TensorFlow multi-class and binary classification #20

Open jordancaraballo opened 2 years ago

jordancaraballo commented 2 years ago

Is your feature request related to a problem? Please describe.

I have been trying to use this library for inference with TensorFlow binary and multi-class segmentation models. I am able to use the Tiler object to perform the predictions, but I have not been able to figure out how to use the Merger for the following cases.

data_shape = (5000, 3000, 4)
tile_shape = (256, 256, 4)
channel_dimension = 2

The output of the model can be either a batch of (N x 256 x 256 x 1) or (N x 256 x 256 x 6), where 6 is the number of classes.

ValueError: Passed data shape ([256 256   1]) does not fit expected tile shape ((256, 256, 4)).

Describe the solution you'd like

It would be great to have additional examples covering similar use cases that perform TensorFlow or PyTorch predictions.

Here is an example of what I have been trying:

import numpy as np
import rioxarray as rxr
import tensorflow as tf
from tiler import Tiler, Merger

model = tf.keras.models.load_model("model.hdf5")

image = rxr.open_rasterio(filename)
image = image.transpose("y", "x", "band")
print(image.shape)

tiler = Tiler(
            data_shape=image.shape,
            tile_shape=(256, 256, 4),
            channel_dimension=2,
            #overlap=0.50
        )

# Calculate and apply extra padding, as well as adjust tiling parameters
#new_shape, padding = tiler.calculate_padding()
#tiler.recalculate(data_shape=new_shape)
#padded_image = np.pad(image, padding, mode="reflect")

merger = Merger(tiler=tiler)#, window="overlap-tile")
print(tiler)

for batch_id, batch in tiler(image, batch_size=512):
    batch = model.predict(batch)
    merger.add_batch(batch_id, 512, batch)

I am probably missing something, but it would be nice to have this documented. Also, the argmax option seems to be hardcoded for channel-first images, which adds extra computational cost when working with channels-last images. Any help would be appreciated.

the-lay commented 2 years ago

Hi Jordan!

As of v0.5.7:

The current logits/argmax functionality is not flexible and can definitely be improved. Moreover, your example highlights a limitation that I overlooked completely. I also use the library to tile images, feed the tiles to a semantic segmentation network and merge them back into the full result, but in my case those images don't have a channel dimension since there is always just one value per pixel, so I never used Merger's logits/argmax functionality and Tiler's channel_dimension at the same time...

Merger's add expects data with the same shape as the tile_shape of the original Tiler. Similarly, add_batch expects data of shape [batch, *tile_shape]. In your example, the batch variable is expected to be of shape (batch, 256, 256, 4).
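For example, a small standalone sketch with random data instead of your real raster (just to show the shapes, not your actual pipeline):

import numpy as np
from tiler import Tiler, Merger

# Toy stand-in for the real raster: 512 x 512 pixels, 4 channels (channels last).
image = np.random.rand(512, 512, 4)

tiler = Tiler(data_shape=image.shape, tile_shape=(256, 256, 4), channel_dimension=2)
merger = Merger(tiler=tiler)

for batch_id, batch in tiler(image, batch_size=2):
    # batch has shape (2, 256, 256, 4), i.e. [batch, *tile_shape],
    # which is exactly what add_batch expects back.
    merger.add_batch(batch_id, 2, batch)

merged = merger.merge()  # back to shape (512, 512, 4)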

If you specify logits for Merger, it would change the expected shape to [logits, *tile_shape] (or [batch, logits, *tile_shape]). If we specify logits in your example (e.g. merger = Merger(tiler, logits=6)), the expected data for add_batch would become (batch, 6, 256, 256, 4), which is also not something that you want to happen.
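For illustration only, continuing the small sketch above:

# Not a recommended setup, just to show the shape change described above.
merger_with_logits = Merger(tiler=tiler, logits=6)
# merger_with_logits.add() would now expect a single tile of shape (6, 256, 256, 4),
# and add_batch() would expect (batch, 6, 256, 256, 4): the logits axis is prepended
# to the full tile_shape, channel dimension included, which does not match a
# channels-last network output of shape (batch, 256, 256, 6).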

I will try to find the time to implement this soon, and sorry for not supporting your use case yet!

jordancaraballo commented 2 years ago

Hi,

Thanks for your response! While this might not be ideal, I was able to work around the channel_dimension constraint by creating a second Tiler object with the number of channels that the network is expected to output. The following is an example of the implementation. If you are okay with this, I can create a pull request with a similar example so that other users I point to this library can benefit from it.

Binary segmentation problem where output is N x 256 x 256 x 1

mode = 'constant'
batch_size = 512

tiler_image = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 4),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)

tiler_mask = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 1),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)

new_shape, padding = tiler_image.calculate_padding()
tiler_image.recalculate(data_shape=new_shape)
tiler_mask.recalculate(data_shape=new_shape)
padded_image = np.pad(image, padding, mode=mode, constant_values=1200)

merger = Merger(tiler=tiler_mask, window="overlap-tile")

for batch_id, batch in tiler_image(padded_image, batch_size=batch_size):
    batch = model.predict(batch)
    merger.add_batch(batch_id, batch_size, batch)

prediction = merger.merge(extra_padding=padding, dtype=image.dtype)
prediction = np.squeeze(np.where(prediction > 0.5, 1, 0).astype(np.int16))
print(prediction.shape, prediction.min(), prediction.max())
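
For the multi-class case I mentioned at the top, I expect the same two-Tiler idea to extend roughly as follows. This is an untested sketch that reuses image, mode, batch_size, tiler_image, new_shape, padding and padded_image from the snippet above, and assumes the model returns channels-last class scores of shape (N, 256, 256, 6):

# Untested sketch: multi-class variant of the same workaround,
# assuming channels-last class scores (N, 256, 256, 6) from the model.
tiler_probs = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 6),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)
tiler_probs.recalculate(data_shape=new_shape)

merger = Merger(tiler=tiler_probs, window="overlap-tile")

for batch_id, batch in tiler_image(padded_image, batch_size=batch_size):
    batch = model.predict(batch)  # assumed shape (N, 256, 256, 6)
    merger.add_batch(batch_id, batch_size, batch)

# Assuming the merged result keeps the class scores on the last axis,
# an argmax over that axis gives the final class map.
probabilities = merger.merge(extra_padding=padding, dtype=np.float32)
prediction = np.argmax(probabilities, axis=-1).astype(np.int16)
print(prediction.shape, prediction.min(), prediction.max())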

The only challenge I am still working around is the presence of artifacts at the boundaries of non-uniform images (e.g. an image of size 90538 x 9148 x 4, with a tile size of 256 x 256 and a batch size of 512). Is this something you have worked around with this library? I can open a new issue on this topic as well. An example is illustrated below: the vertical lines at the left border of the image are not expected.

[Screenshot: unexpected vertical line artifacts along the left border of the merged prediction]

the-lay commented 2 years ago

Nice workaround! Hopefully in the near future it will not be needed anymore!

Not sure how helpful these suggestions are, but:

I'm curious to hear if you manage to fix this :)