tomvars opened this issue 2 years ago
Hi, @tomvars. Thanks for the proposal. I think this is an excellent idea.
One potential solution would be leveraging the PyTorch Hub tools. I got this code working. What do you think?
```python
import torch
import torchio as tio

# Load a sample subject and show it before preprocessing
fpg = tio.datasets.FPG()
fpg.plot(reorient=False)

# Fetch the preprocessing transform published in the model repo via PyTorch Hub
repo = 'fepegar/resseg:add-preprocessing-hubconf'
function_name = 'get_preprocessing_transform'
input_path = fpg.t1.path
preprocess = torch.hub.load(repo, function_name, input_path, image_name='t1', force_reload=True)

# Apply the transform and show the result
preprocessed = preprocess(fpg)
preprocessed.plot(reorient=False)
```
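For context, the entry point that `torch.hub.load` looks for is just a function defined in a `hubconf.py` at the root of the repo. A rough sketch of what that could look like is below; the pipeline contents are made up, and the real entry point in `fepegar/resseg` may differ:

```python
# hubconf.py, at the root of the model repository.
# Sketch only: the actual function in fepegar/resseg is defined by that repo.
dependencies = ['torchio']


def get_preprocessing_transform(input_path, image_name='image'):
    """Return the preprocessing pipeline to apply before inference."""
    import torchio as tio

    # input_path could be used to adapt the pipeline to the input image
    # (e.g. to fit intensity landmarks); it is unused in this sketch.
    return tio.Compose([
        tio.ToCanonical(),
        tio.Resample(1),
        tio.ZNormalization(include=[image_name]),
    ])
```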
I really like this API! You could maybe create a new repo like `fepegar/torchiohub:main` and have a single `hubconf.py` file as the access point to different preprocessing functions. In that repo, users could append their transform functions to a large `transforms.py` file, and the `hubconf.py` would have lines such as `from transforms import ronneberger_unet_2015_transform` (see the sketch below).
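For concreteness, a rough sketch of what that hypothetical `fepegar/torchiohub` repo could contain; the function name is the placeholder from the comment above and the pipeline is made up:

```python
# transforms.py -- contributors append one factory function per paper/model
import torchio as tio


def ronneberger_unet_2015_transform():
    # Placeholder pipeline; each contributor would define their own here
    return tio.Compose([
        tio.RescaleIntensity((0, 1)),
        tio.RandomFlip(),
        tio.RandomAffine(),
    ])
```

The `hubconf.py` next to it would then just contain `dependencies = ['torchio']` plus lines like `from transforms import ronneberger_unet_2015_transform`, so that `torch.hub.load('fepegar/torchiohub:main', 'ronneberger_unet_2015_transform')` returns the pipeline.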
I think it's more convenient to allow users to use their own `hubconf.py` in their repos, because that way each project keeps its transforms next to its own code and weights, and there is no central repository that needs to be maintained.
So the contribution to this library (which I'm happy to write) would be documentation on how to set up transforms for reproducibility on top of PyTorch Hub. Does that sound good?
That makes sense 👍 Thoughts on introducing a class method on `Transform` called `from_hub`, which would wrap the `torch.hub.load` call and pass in the relevant arguments?
You mean something like this?
```python
@classmethod
def from_hub(cls, *args, **kwargs):
    # Thin wrapper that forwards everything to torch.hub.load
    return torch.hub.load(*args, **kwargs)
```
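Assuming that method ends up on the base `Transform` class, usage would look something like this; the repo and entry-point arguments are just the ones from the earlier example, and `from_hub` itself does not exist yet:

```python
import torchio as tio

colin = tio.datasets.Colin27()

# Hypothetical: from_hub is the proposed class method, not an existing API
preprocess = tio.Transform.from_hub(
    'fepegar/resseg:add-preprocessing-hubconf',
    'get_preprocessing_transform',
    colin.t1.path,
    image_name='t1',
)
preprocessed = preprocess(colin)
```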
Hey, I forgot to share some experiments I conducted. The code below needs `unet` to be pip-installed:
```python
import torch
import torchio as tio

colin = tio.datasets.Colin27()
path = colin.t1.path

repo = 'fepegar/resseg:add-preprocessing-hubconf'
entrypoint = 'get_preprocessing_transform'

# First call fetches and caches the repo
torch.hub.load(repo, entrypoint, path)
# image_name tells the transform which image in the subject to use
transform = torch.hub.load(repo, entrypoint, path, image_name='t1')
transform(colin).plot()
```
Here, `HistogramStandardization` makes it a bit awkward (its landmarks need to be computed beforehand and shipped with the repo), but things work. We should write a tutorial about this. If you think the class method would be helpful, feel free to contribute it in a PR!
🚀 Feature
Introduce a public TorchIO hub where researchers can share the transform object used to randomly sample and augment their data during training, so that anyone can fetch it with one line of code:

`transform = tio.from_hub("cool_recent_paper")`
Motivation
DL researchers and practitioners hoping to reproduce other people's work can easily fetch model weights and architecture definitions (e.g. Torch Hub or MONAI Bundle), training parameters (e.g. AutoModel from Hugging Face), and preprocessing strategies (e.g. AutoFeatureExtractor from Hugging Face). However, one thing that is still an obstacle to reproducing someone's setup in a few lines of code is data augmentation. Libraries like Albumentations and TorchIO provide a variety of common data augmentation strategies, but they lack the hub features of Hugging Face or PyTorch that make it easy to store and fetch those strategies.
Pitch
Not sure how you would implement this. As an MVP, you could have a separate repo where users submit their transforms as code, plus a big dictionary lookup mapping a chosen string to each transform (a rough sketch is below).
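As a very rough sketch of that MVP; the registry, the example entry, and `from_hub` here are all hypothetical:

```python
import torchio as tio


def cool_recent_paper_transform() -> tio.Transform:
    # Example entry contributed by a paper's authors
    return tio.Compose([
        tio.RandomAffine(),
        tio.RandomFlip(),
        tio.RandomNoise(),
    ])


# The "big dictionary lookup" between a chosen string and each transform
TRANSFORM_REGISTRY = {
    'cool_recent_paper': cool_recent_paper_transform,
}


def from_hub(name: str) -> tio.Transform:
    """Rebuild the augmentation pipeline registered under the given name."""
    return TRANSFORM_REGISTRY[name]()
```

Calling `from_hub('cool_recent_paper')` would then rebuild that paper's pipeline, matching the one-liner in the feature description.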