fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

Feature proposal: TorchIO hub - A system to store and fetch transform objects for reproducibility #972

Open tomvars opened 2 years ago

tomvars commented 2 years ago

🚀 Feature

Introduce a public TorchIO hub where researchers can save the transform object used to randomly sample and augment their data during training, so that others can fetch it with one line of code: transform = tio.from_hub("cool_recent_paper")

Motivation

DL researchers and practitioners hoping to reproduce other people's work can easily fetch model weights and architectural definitions (e.g., Torch Hub or MONAI Bundle), training parameters (e.g., AutoModel from HuggingFace) and preprocessing strategies (e.g., AutoFeatureExtractor from HuggingFace). However, data augmentation is still an obstacle to reproducing someone's setup in a few lines of code. Libraries like Albumentations and TorchIO provide a variety of common data augmentation strategies, but they lack the hub features of HuggingFace or PyTorch for easily storing and fetching those strategies.

Pitch

Not sure how you would implement this. As an MVP, you could have a separate repo where users submit their transforms as code, plus a big dictionary that maps a chosen string to each transform.

fepegar commented 2 years ago

Hi, @tomvars. Thanks for the proposal. I think this is an excellent idea.

One potential solution would be leveraging the PyTorch Hub tools. I got this code working. What do you think?

import torchio as tio
fpg = tio.datasets.FPG()
fpg.plot(reorient=False)

Figure_1

import torch
repo = 'fepegar/resseg:add-preprocessing-hubconf'
function_name = 'get_preprocessing_transform'
input_path = fpg.t1.path
preprocess = torch.hub.load(repo, function_name, input_path, image_name='t1', force_reload=True)
preprocessed = preprocess(fpg)
preprocessed.plot(reorient=False)

Figure_2

tomvars commented 2 years ago

I really like this API! You could create a new repo such as fepegar/torchiohub:main with a single hubconf.py file as the access point to the different preprocessing functions. In that repo, users could append their transform functions to a large transforms.py file, and hubconf.py would have lines such as from transforms import ronneberger_unet_2015_transform.

fepegar commented 2 years ago

I think it's more convenient to allow users to use their own hubconf in their repos because

  1. This is what PyTorch does, so people are familiar with the syntax etc.
  2. Sometimes, getting a transform needs some special code. The snippet I shared is an example in which additional libraries or files might be needed just to compute the transform, and we wouldn't want to put everyone's code in the same repo.

So the contribution to this library (which I'm happy to write) would be documentation on how to set up transforms for reproducibility on top of PyTorch Hub. Does that sound good?
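For reference, a per-repo hubconf.py along these lines is one way a project could expose its transform through PyTorch Hub. It is a minimal sketch, not the file from the resseg branch: the entry point name get_augmentation_transform and the particular transforms chosen are assumptions made for illustration. The dependencies list and the convention that public functions in hubconf.py act as entry points are standard PyTorch Hub behaviour.

```python
# hubconf.py (sketch): placed at the root of a GitHub repo so PyTorch Hub can find it.

# PyTorch Hub checks that these packages are importable before loading entry points.
dependencies = ["torchio"]


def get_augmentation_transform(flip_probability=0.5):
    """Hypothetical entry point returning the transform used for training."""
    # Imported lazily so this file can be inspected without TorchIO installed
    import torchio as tio

    return tio.Compose([
        tio.ToCanonical(),      # reorient to RAS+
        tio.Resample(1),        # resample to 1 mm isotropic spacing
        tio.ZNormalization(),   # zero mean, unit variance
        tio.RandomFlip(axes=("LR",), flip_probability=flip_probability),
        tio.RandomAffine(),     # random scaling and rotation with default ranges
    ])
```

With such a file in place, a call of the form torch.hub.load(repo, 'get_augmentation_transform') would return the composed transform, where repo is the owner/name string of the repository hosting the file.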

tomvars commented 2 years ago

That makes sense 👍 Thoughts on introducing a from_hub class method on Transform that wraps the torch.hub.load call and passes in the relevant arguments?

fepegar commented 2 years ago

You mean something like this?

@classmethod
def from_hub(cls, *args, **kwargs):
    # Thin wrapper: forward all arguments unchanged to torch.hub.load
    return torch.hub.load(*args, **kwargs)

fepegar commented 1 year ago

Hey, I forgot to share some experiments I conducted. The code below needs unet to be installed with pip:

import torch
import torchio as tio

colin = tio.datasets.Colin27()
path = colin.t1.path
# Fetch the preprocessing transform through the hubconf in the repo branch
transform = torch.hub.load('fepegar/resseg:add-preprocessing-hubconf', 'get_preprocessing_transform', path, image_name='t1')
transform(colin).plot()

Here, HistogramStandardization makes it a bit awkward, but things work. We should write a tutorial about this. If you think the class method would be helpful, feel free to contribute with a PR!
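As a starting point for such a PR, the class method discussed earlier in the thread could be sketched as below. This is an assumption-laden illustration rather than TorchIO code: HubLoadable is a hypothetical mixin, and the keyword-only _load parameter is added here only so the wrapper can be exercised without network access. The one addition over a plain pass-through is a type check on whatever the entry point returns.

```python
from typing import Any, Callable, Optional


class HubLoadable:
    """Hypothetical mixin adding from_hub; not part of TorchIO."""

    @classmethod
    def from_hub(
        cls,
        repo: str,
        entrypoint: str,
        *args: Any,
        _load: Optional[Callable[..., Any]] = None,
        **kwargs: Any,
    ):
        if _load is None:
            # Imported lazily; torch.hub.load(repo_or_dir, model, *args, **kwargs)
            import torch
            _load = torch.hub.load
        obj = _load(repo, entrypoint, *args, **kwargs)
        # Guard against entry points that return something else (e.g. a model)
        if not isinstance(obj, cls):
            raise TypeError(
                f"{entrypoint!r} returned {type(obj).__name__}, "
                f"expected an instance of {cls.__name__}"
            )
        return obj
```

If Transform inherited from such a mixin, tio.Transform.from_hub(repo, function_name, ...) would behave like torch.hub.load but fail early when the entry point does not return a transform.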