BloodAxe / pytorch-toolbelt

PyTorch extensions for fast R&D prototyping and Kaggle farming
MIT License
1.52k stars, 122 forks

How to Implement TTA For binary segmentation #95

Closed: chefkrym closed this issue 1 year ago

chefkrym commented 1 year ago

Would anyone be kind enough to share code showing how to use TTA for binary segmentation with this library? I have my trained model weights but can't figure out how to use pytorch-toolbelt.

Thank you.

BloodAxe commented 1 year ago

Assuming you have a model that takes an input of shape [B, Cin, H, W] and outputs logits as a single tensor of shape [B, Cout, H, W], where Cout is 1 for binary segmentation (Cout > 1 works as well):

from torch import nn
from pytorch_toolbelt.inference import tta

# `model` is your trained nn.Module returning logits of shape [B, Cout, H, W]
model = nn.Sequential(model, nn.Sigmoid())  # Apply sigmoid activation to logits predictions
model = tta.GeneralizedTTA(model, augment_fn=tta.fliplr_image_augment, deaugment_fn=tta.fliplr_image_deaugment)

Apply the sigmoid externally only if your model does not already do so; it is shown here for reference. Once your model is wrapped in tta.GeneralizedTTA, that's it: run inference as you normally would and TTA is done for you inside. The class is even traceable with torch.jit, so you can export the wrapped model if you need to.
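To build intuition for what happens inside the wrapper, here is a minimal pure-PyTorch sketch of horizontal-flip TTA. It assumes that "fliplr" TTA means: run the model on the image and on its left-right mirror, flip the second prediction back so the pixels line up, and average the two. FlipLRTTA is a hypothetical class written for illustration only; it is not pytorch-toolbelt's implementation, whose batching and reduction details may differ.

```python
import torch
from torch import nn


class FlipLRTTA(nn.Module):
    """Hypothetical stand-in illustrating horizontal-flip TTA (not pytorch-toolbelt's code)."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # image: [B, Cin, H, W]
        batch_size = image.size(0)
        # Augment: original + left-right mirror stacked along the batch axis -> [2B, Cin, H, W]
        augmented = torch.cat([image, torch.flip(image, dims=[3])], dim=0)
        output = self.model(augmented)  # [2B, Cout, H, W]
        original, flipped = output[:batch_size], output[batch_size:]
        # Deaugment: mirror the flipped prediction back, then average the two views
        return (original + torch.flip(flipped, dims=[3])) / 2


torch.manual_seed(0)
# Toy segmentation "model" ending in a sigmoid, as in the snippet above
net = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid())
tta_net = FlipLRTTA(net).eval()

with torch.no_grad():
    probs = tta_net(torch.rand(2, 3, 32, 32))
print(probs.shape)  # torch.Size([2, 1, 32, 32])
```

Because the two views are averaged after the sigmoid, the output stays a probability map with the same shape as the model's own output, which is why the wrapped model can be used as a drop-in replacement. Other augment/deaugment pairs follow the same pattern with more views.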

chefkrym commented 1 year ago

Thank you so much, @BloodAxe, I am really grateful. I'm rather new to this and unsure how to wrap my model. I've shared my code below; could you kindly guide me? I want to apply TTA to just one test image (the last block of cells in my notebook). Thank you, sir.

https://colab.research.google.com/drive/1xJ62lFGlaVpbw6WPFfZCAR8_IKX4jclR?usp=sharing

BloodAxe commented 1 year ago

The best way to learn new things is to try to understand how they work. Take a look at the code and the corresponding tests; they should give you a good intuition for how it works and how to use it in your case.
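For the single-image case asked about above, a minimal sketch, assuming the wrapped model ends with a sigmoid (as in the earlier snippet) and that the image is already a preprocessed [Cin, H, W] tensor with the same resizing and normalization used during training. The wrapper, like the model, expects a batched [B, Cin, H, W] input, so the image gets a batch dimension first. predict_single_image and the 0.5 threshold are hypothetical choices for illustration, not part of pytorch-toolbelt.

```python
import torch
from torch import nn


def predict_single_image(model: nn.Module, image: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: run a (TTA-wrapped) binary segmentation model on one image.

    Assumes `model` outputs probabilities in [0, 1] with Cout == 1.
    Takes a [Cin, H, W] tensor and returns a [H, W] uint8 mask of 0s and 1s.
    """
    model.eval()
    with torch.no_grad():
        batch = image.unsqueeze(0)  # [Cin, H, W] -> [1, Cin, H, W]
        probs = model(batch)        # [1, 1, H, W]
    mask = probs[0, 0] > threshold  # [H, W] boolean mask
    return mask.to(torch.uint8)


# Usage with a stand-in model; in practice pass your tta.GeneralizedTTA-wrapped model
torch.manual_seed(0)
stand_in = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid())
mask = predict_single_image(stand_in, torch.rand(3, 64, 64))
print(mask.shape, mask.unique())
```

Calling model.eval() and torch.no_grad() matters at inference time: it switches layers such as BatchNorm and Dropout to inference behavior and skips gradient bookkeeping.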