skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License
5.84k stars 388 forks

Add (torchvision) transform wrapper #543

Open ottonemo opened 4 years ago

ottonemo commented 4 years ago

As discussed in https://github.com/skorch-dev/skorch/issues/524, it is desirable to be able to tune the parameters of dataset transforms, such as those from torchvision, with parameter searches (e.g. a grid search). For this, we should provide a wrapper like the one shown in https://github.com/skorch-dev/skorch/issues/524#issuecomment-533883353.
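The wrapper could look roughly like the following. This is a minimal sketch, not skorch's API. `ComposeTransformer`, its `transform_params` argument, and the stand-in classes `AddN` and `Scale` are hypothetical, and `X` is assumed to be a sequence of raw inputs. The sketch only implements the `get_params`/`set_params`/`fit`/`transform` methods that scikit-learn's `Pipeline` and `clone` rely on. A real version would more likely subclass `sklearn.base.BaseEstimator` and `TransformerMixin`. Any class whose instances are callable works, so the idea is not tied to vision.

```python
from typing import Any, Callable, Dict, List, Optional


class ComposeTransformer:
    """Hypothetical sketch of the proposed wrapper (not part of skorch).

    Keeps transform *classes* and their keyword arguments apart, so a
    parameter search can change the arguments before instantiation.
    """

    def __init__(self, transforms: List[type],
                 transform_params: Optional[Dict[str, Dict[str, Any]]] = None):
        # store arguments unchanged, as sklearn's clone() expects
        self.transforms = transforms
        self.transform_params = transform_params

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        params: Dict[str, Any] = {"transforms": self.transforms,
                                  "transform_params": self.transform_params}
        if deep:
            # expose nested keys such as "RandomResizedCrop__size"
            for name, kwargs in (self.transform_params or {}).items():
                for key, value in kwargs.items():
                    params[f"{name}__{key}"] = value
        return params

    def set_params(self, **params: Any) -> "ComposeTransformer":
        # top-level constructor arguments are replaced directly
        for key in ("transforms", "transform_params"):
            if key in params:
                setattr(self, key, params.pop(key))
        if not params:
            return self
        names = {cls.__name__ for cls in self.transforms}
        # copy rather than mutate the dict that was passed to __init__
        new_params = {k: dict(v) for k, v in (self.transform_params or {}).items()}
        for key, value in params.items():
            name, sep, arg = key.partition("__")
            if not sep or name not in names:
                raise ValueError(f"Invalid parameter {key!r} for ComposeTransformer")
            new_params.setdefault(name, {})[arg] = value
        self.transform_params = new_params
        return self

    def _build(self) -> Callable[[Any], Any]:
        # instantiate each class with the keyword arguments collected for it
        kwargs = self.transform_params or {}
        instances = [cls(**kwargs.get(cls.__name__, {})) for cls in self.transforms]

        def composed(x: Any) -> Any:
            for t in instances:
                x = t(x)
            return x

        return composed

    def fit(self, X: Any, y: Any = None) -> "ComposeTransformer":
        return self  # nothing to learn

    def transform(self, X: Any) -> List[Any]:
        composed = self._build()
        return [composed(x) for x in X]


# usage with hypothetical stand-in transform classes
class AddN:
    def __init__(self, n: int = 1):
        self.n = n

    def __call__(self, x: int) -> int:
        return x + self.n


class Scale:
    def __init__(self, factor: int = 2):
        self.factor = factor

    def __call__(self, x: int) -> int:
        return x * self.factor


ct = ComposeTransformer([AddN, Scale])
ct.set_params(AddN__n=10)    # plays the role of RandomResizedCrop__size
print(ct.transform([1, 2]))  # → [22, 24]
```

The wrapper stores classes plus keyword arguments and instantiates them only in `transform`. This lets a search change an argument such as `RandomResizedCrop__size` between fits. The trade-off is that the transforms are rebuilt on every `transform` call.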

Usage example:

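import torchvision
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier

mnist_train = torchvision.datasets.MNIST(...)
mnist_y = ...  # training targets; not specified in this example
net = NeuralNetClassifier(...)
model = Pipeline([
    # ComposeTransformer is the proposed wrapper and does not exist yet.
    # Transform classes (not instances) are passed, so their arguments
    # can be set later via set_params.
    ('transform', ComposeTransformer([
        torchvision.transforms.RandomResizedCrop,
        torchvision.transforms.RandomHorizontalFlip,
        torchvision.transforms.ToTensor,
        torchvision.transforms.Normalize,
    ])),
    ('clf', net),
])

# <step name>__<transform class name>__<argument>
model.set_params(transform__RandomResizedCrop__size=224)
model.fit(mnist_train, mnist_y)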
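The key passed to `model.set_params` is resolved one level at a time. scikit-learn's `set_params` splits off the part before the first `__` and forwards the rest to the named sub-estimator. Here that means first the `transform` step, and then the transform class inside the wrapper. A minimal sketch of that routing follows. `route_param` is a hypothetical helper for illustration only; scikit-learn does this inside `set_params` and does not expose such a function.

```python
from typing import Any, List, Tuple


def route_param(key: str, value: Any) -> Tuple[List[str], str, Any]:
    """Split a nested parameter name into (path, argument, value).

    Each loop iteration splits off the part before the first '__',
    mirroring how each level of set_params hands on the remainder.
    """
    path: List[str] = []
    while "__" in key:
        head, _, key = key.partition("__")
        path.append(head)
    return path, key, value


print(route_param("transform__RandomResizedCrop__size", 224))
# → (['transform', 'RandomResizedCrop'], 'size', 224)
```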

It would be nice to support other modalities besides vision, but this is not a strict requirement. If a general wrapper turns out to be very complicated, we should support vision first and look for a general solution later (one reason being that PyTorch's audio and text libraries, torchaudio and torchtext, are in flux right now).
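One reason a general wrapper may be feasible is that composing transforms does not depend on torchvision. It only requires each transform to be callable. The following is a minimal sketch under the assumption that samples are `(x, y)` pairs, as torchvision datasets such as MNIST return them. `TransformedDataset` and `compose` are hypothetical names. Plain string functions stand in for a text pipeline.

```python
from typing import Any, Callable, Sequence, Tuple


def compose(*fns: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Chain callables left to right, like a modality-agnostic Compose."""
    def composed(x: Any) -> Any:
        for fn in fns:
            x = fn(x)
        return x
    return composed


class TransformedDataset:
    """Apply `transform` to the input of each (x, y) sample on access.

    Because the transform runs inside __getitem__, random augmentations
    would be re-drawn every time a sample is read.
    """

    def __init__(self, dataset: Sequence[Tuple[Any, Any]],
                 transform: Callable[[Any], Any]):
        self.dataset = dataset
        self.transform = transform

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x, y = self.dataset[idx]
        return self.transform(x), y


# text stand-in: lowercase, then split on whitespace
raw = [("Hello World", 0), ("PyTorch Text", 1)]
ds = TransformedDataset(raw, compose(str.lower, str.split))
print(len(ds), ds[1])  # → 2 (['pytorch', 'text'], 1)
```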

BenjaminBossan commented 4 years ago

A couple of problems with the proposed solution: