skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

adding reference to model in unpack_data(batch)? #878

Closed · cinntamani closed this issue 1 year ago

cinntamani commented 1 year ago

Thank you for writing such an elegant framework for linking PyTorch and sklearn.

When I work with skorch and try to extend it to cases where the dataset takes on an unusual structure and unpacking the data requires some additional information, I find the bottleneck to be the skorch.dataset.unpack_data(batch) function. Unlike other helper methods, this function does not receive a reference to the skorch model. I managed to add extra information to the skorch model through __init__(), and that information determines how to unpack the data. If there were a reference to the skorch model in unpack_data(batch), I could override this function alone without changing much else of skorch. This is just my suggestion for your consideration.
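
Roughly, I imagine something like the following (the net argument is hypothetical, just to illustrate the proposal):

from skorch.dataset import unpack_data

# hypothetical variant: today's unpack_data(batch) plus a reference to the net
def unpack_data_with_net(batch, net=None):
    unpack_info = getattr(net, 'unpack_info', None)
    if unpack_info is not None:
        # use the extra information stored on the net (e.g. set in __init__)
        # to decide how to unpack this batch
        return unpack_info(batch)
    return unpack_data(batch)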

Thank you again!

BenjaminBossan commented 1 year ago

Thank you for explaining your issue so well and coming up with a possible solution. In general, I'm not opposed to the idea of having a self.unpack_data on the net which by default just calls unpack_data. It would be possible to achieve a similar thing by patching skorch.dataset.unpack_data, but that's a bit ugly.
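
As a rough sketch of that idea (both the method and the internal change are hypothetical, nothing skorch currently provides):

from skorch import NeuralNet
from skorch.dataset import unpack_data

class NeuralNetWithUnpack(NeuralNet):
    # hypothetical: a method that by default delegates to the module-level
    # helper; skorch's internals would then call self.unpack_data(batch)
    # instead of unpack_data(batch), so subclasses could override it
    def unpack_data(self, batch):
        return unpack_data(batch)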

Our main reason for not implementing this so far was that we thought that if you have a case where unpack_data doesn't work, you probably have to override train_step_single et al. anyway, so you don't need to use unpack_data. Apparently, that's not the case for you. Could you please explain what your case is, e.g. how you would solve the problem if you had a reference to the net instance? That way, we can determine together whether the proposed solution is the best one.

cinntamani commented 1 year ago

Thanks a lot for your quick response, Benjamin!

The main reason I propose modifying unpack_data() is that I hope to avoid modifying the source code of skorch and to apply patches only where necessary. With this goal in mind, if the interface of unpack_data() does not have to change on the user's side, then I only need to patch this one function and override the reference in all the modules that call it, as sketched below. However, if the interface of unpack_data() does have to change on the user's side, then I also need to modify all the modules that call this function, and then further patch the modules that call those modified modules. This is the crux of the issue.
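
For example, the patching would look roughly like this (assuming the calling modules import unpack_data by name, which is why each module's reference has to be overridden separately):

import skorch.dataset
import skorch.net
import skorch.scoring

def my_unpack_data(batch):
    # custom unpacking logic for my unusual Dataset structure
    X, y = batch[0], batch[1]
    return X, y

# replacing skorch.dataset.unpack_data alone is not enough, because the
# calling modules hold their own references to the original function
skorch.dataset.unpack_data = my_unpack_data
skorch.net.unpack_data = my_unpack_data
skorch.scoring.unpack_data = my_unpack_data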

I agree that if unpack_data() needs changes, we also need to change the following methods:

skorch.net.train_step_single()
skorch.net.validation_step()
skorch.net.evaluation_step()
skorch.net.infer()

However, their interface need not change (partly because each of these methods contains a reference to the net via self), so I do not need to change the calling modules; see the sketch below.
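
For instance, I could override one of them in a subclass without touching its signature, loosely following skorch's actual implementation (my_unpack is a hypothetical helper that reads the extra information off self):

from skorch import NeuralNet

class MyNet(NeuralNet):
    def train_step_single(self, batch, **fit_params):
        # same interface as skorch's version; only the unpacking changes
        self._set_training(True)
        Xi, yi = self.my_unpack(batch)
        y_pred = self.infer(Xi, **fit_params)
        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
        loss.backward()
        return {'loss': loss, 'y_pred': y_pred}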

The actual problem I am trying to solve is as follows: I use skorch in combination with sklearn's GridSearchCV. However, the data I work with has a natural grouping structure, where I need to split at the group level, but the modeling happens at the observation level. Furthermore, sklearn in general cannot handle a torch Dataset, so I cannot hope to use a grouped CV splitter. As a result, I can only pass the group index to GridSearchCV and reconstruct the data inside the skorch NeuralNet. To do this, I pass a function to the NeuralNet constructor argument dataset, plus another argument holding the actual dataset, so that this function can use it together with the group index passed from sklearn's CV splitter to reconstruct the data.
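
My setup looks roughly like this (all names and data are illustrative; MyModule is a placeholder for my actual module class):

import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset
from skorch import NeuralNet

# the real observation-level dataset, held outside of sklearn
full_ds = TensorDataset(torch.randn(15, 2), torch.zeros(15, dtype=torch.long))
group_of_sample = np.repeat([0, 1, 2], 5)  # group label per observation

def make_dataset(X, y):
    # X carries only the group ids selected by sklearn's CV splitter;
    # reconstruct the observation-level subset from the full dataset
    idx = np.flatnonzero(np.isin(group_of_sample, X))
    return Subset(full_ds, idx)

net = NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss, dataset=make_dataset)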

There are cases where the Dataset contains more variables than X, with different shapes that cannot be combined with X, and the reconstruction of the data depends on some additional overall task information describing how to combine those variables (which can vary over the course of the grid search). The only way I can think of right now is to pass that overall task information as arguments when initializing the skorch NeuralNet.
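
Passing such information through the net could look like this (a sketch following skorch's documented pattern for extra __init__ arguments; task_info and MyModule are made-up names):

import torch
from skorch import NeuralNet

class TaskAwareNet(NeuralNet):
    # extra constructor argument carried on the net; because it appears in
    # __init__'s signature, sklearn's get_params/clone can handle it, so it
    # could also vary as a hyperparameter during the grid search
    def __init__(self, *args, task_info=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.task_info = task_info

net = TaskAwareNet(MyModule, criterion=torch.nn.CrossEntropyLoss,
                   task_info={'combine': 'stack'})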

This issue can be handled mostly by modifying the following functions that call unpack_data():

skorch.net.train_step_single()
skorch.net.validation_step()
skorch.net.evaluation_step()
skorch.net.infer()

However, there is another call path to unpack_data() that does not go through these functions: skorch.scoring.loss_scoring(). This function is called when sklearn's GridSearchCV scores the performance directly. The issue is that, inside loss_scoring(), unpack_data() is called again to obtain the true y, in addition to evaluation_step() being called to get the predicted y. Hence, I can only put the reconstruction logic inside unpack_data(), yet there is no way to pass the additional overall task information along this call path, even though loss_scoring() itself, which calls unpack_data(), does have a reference to the NeuralNet available; see the sketch below.
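
In outline, that call path looks like this (simplified from my reading, not the actual source):

from skorch.dataset import unpack_data

def loss_scoring_sketch(net, X, y=None):
    dataset = net.get_dataset(X, y)
    losses = []
    for batch in net.get_iterator(dataset, training=False):
        Xi, yi = unpack_data(batch)       # module-level call: no net reference
        y_pred = net.evaluation_step(Xi)  # but the net *is* in scope here
        losses.append(net.get_loss(y_pred, yi))
    return losses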

To summarize: since I hope to avoid modifying the source code of skorch and to apply patches only where necessary, adding a reference to the skorch NeuralNet instance to unpack_data() would make it much easier to handle unusual Datasets, because arbitrary information could then be passed through the NeuralNet instance. The reference acts as a placeholder that keeps the API fixed for users who want to override unpack_data().

Thanks a lot for your attention! :)

BenjaminBossan commented 1 year ago

Thank you for taking the time to provide your detailed summary of the situation. This already shows that my initial idea

In general, I'm not opposed to the idea of having a self.unpack_data on the net which by default just calls unpack_data

would not actually solve your problem.

Regarding the problem itself, my intuition tells me that changing unpack_data is not the right way to approach it. If I understand correctly, you are basically bypassing the CV split used by sklearn's grid search and instead performing the split inside NeuralNet. Presumably, if you didn't have to do that, you wouldn't need to add the groups to your data and thus wouldn't have the problem with unpack_data. So let's try to address this problem.

So you mentioned that

sklearn in general cannot handle a torch Dataset, so I cannot hope to use a grouped CV splitter

This is not completely true. We do provide a wrapper called SliceDataset, which allows you to use datasets with grid search. Then you could use, say, GroupKFold from sklearn to get the grouped split. Let me show you some simple code below:

import numpy as np
from skorch.dataset import Dataset
from skorch import NeuralNetClassifier
from skorch.scoring import loss_scoring
from skorch.helper import SliceDataset
from sklearn.model_selection import GroupKFold, GridSearchCV
import torch
from torch import nn

# create an X that mirrors the groups, each group is either 0, 1, or 2
X = np.repeat(np.array([0, 1, 2, 0, 1, 2]), 5).reshape(2, 15).T.astype(np.float32)
print(X)
# shows:
# array([[0., 0.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.],
#        [1., 1.],
#        [1., 1.],
#        [1., 1.],
#        [1., 1.],
#        [1., 1.],
#        [2., 2.],
#        [2., 2.],
#        [2., 2.],
#        [2., 2.],
#        [2., 2.]], dtype=float32)

groups = X[:, 0].copy()  # groups are same as X's 1st column
y = np.zeros(15, dtype=int)  # doesn't matter

# put the data into a SliceDataset
ds = Dataset(X, y)
X_slice_ds = SliceDataset(ds, idx=0)
y_slice_ds = SliceDataset(ds, idx=1)

# let's define a simple net that prints X
class ClassifierModule(nn.Module):
    def __init__(self):
        super(ClassifierModule, self).__init__()
        self.dense = nn.Linear(2, 3)
        self.sm = nn.Softmax(dim=-1)
    def forward(self, X, **kwargs):
        print(X)
        return self.sm(self.dense(X))

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=1,
    lr=0.1,
    train_split=False,
)

# use grid search with GroupKFold and loss_scoring
search = GridSearchCV(
    net, {'lr': [0.1, 0.2]}, scoring=loss_scoring,
    cv=GroupKFold(n_splits=3), refit=False,
)
search.fit(X_slice_ds, y_slice_ds, groups=groups)

# prints the following X's

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])

As you can see, the module was only ever called with 2 of the 3 groups for training (10 samples at a time) and with 1 of the 3 groups for validation (5 samples). Therefore, the group split was successful.

I hope what I tried to convey is clear and that it helps you solve your problem without having to change unpack_data.

cinntamani commented 1 year ago

Thank you so much for the detailed example, Benjamin! I knew about SliceDict but hadn't noticed SliceDataset. This is a really cool use of this class, and it indeed solved my problem. Thanks a lot!