skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

Error predicting when using `compile=True` with `NeuralNetBinaryClassifier` #1057

Closed: foster999 closed this issue 4 months ago

foster999 commented 4 months ago

I'm using Python 3.11.3 with skorch==1.0.0.

The error disappears when I drop the `compile` argument. Similar examples with `NeuralNetClassifier` don't seem to raise it.

Minimal example

import numpy as np
import torch.nn.functional as F
from skorch import NeuralNetBinaryClassifier
from torch import nn

X = np.random.normal(size=(200, 100)).astype("float32")
y = np.zeros(200).astype("float32")
y[:100] = 1

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 50)
        self.output = nn.Linear(50, 1)  # a single logit per sample

    def forward(self, X):
        out = F.relu(self.linear(X))
        out = self.output(out)
        # drop the trailing dimension: shape (batch, 1) -> (batch,)
        return out.squeeze(-1)

net = NeuralNetBinaryClassifier(MyNet, max_epochs=1, compile=True)

net.fit(X, y)
y_proba = net.predict_proba(X)

Raises the following (traceback taken from my full training script, which calls `classifier.fit`):

Traceback (most recent call last):
  File "/home/jupyter/model_training_tracking/train_sklearn_single_nn.py", line 98, in <module>
    classifier.fit(
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/classifier.py", line 348, in fit
    return super().fit(X, y, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1319, in fit
    self.partial_fit(X, y, **fit_params)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1278, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1196, in fit_loop
    self.notify("on_epoch_end", **on_epoch_kwargs)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 386, in notify
    getattr(cb, method_name)(self, **cb_kwargs)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/callbacks/scoring.py", line 489, in on_epoch_end
    current_score = self._scoring(cached_net, X_test, y_test)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/callbacks/scoring.py", line 181, in _scoring
    return scorer(net, X_test, y_test)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 253, in __call__
    return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 345, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 87, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/utils/_response.py", line 210, in _get_response_values
    y_pred = prediction_method(X)
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/classifier.py", line 381, in predict
    return (y_proba[:, 1] > self.threshold).astype('uint8')
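The last frame above evaluates `y_proba[:, 1] > self.threshold`, which only works if `predict_proba` returned a 2-D array whose column 1 holds the probability of class 1 (the two-column layout `NeuralNetBinaryClassifier` normally produces). The exception message is cut off in the traceback, so the following numpy-only sketch is an assumption about the failure mode, not a diagnosis: it shows the same expression succeeding on an `(n, 2)` array and raising `IndexError` on a 1-D array. `predict_from_proba` and `THRESHOLD` are hypothetical names; 0.5 is the class's default `threshold`.

```python
import numpy as np

THRESHOLD = 0.5  # default threshold of NeuralNetBinaryClassifier


def predict_from_proba(y_proba):
    # Same expression as the last traceback frame: requires a 2-D array with a column 1
    return (y_proba[:, 1] > THRESHOLD).astype("uint8")


p = np.array([0.2, 0.7, 0.9], dtype="float32")  # P(class 1) for each sample

# Expected layout: column 0 = P(class 0), column 1 = P(class 1)
y_proba_2d = np.stack([1 - p, p], axis=1)
print(predict_from_proba(y_proba_2d))  # [0 1 1]

# Hypothetical failure mode: the probabilities arrive as a 1-D array
try:
    predict_from_proba(p)
    failed = False
except IndexError as exc:
    failed = True
    print("IndexError:", exc)
```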

BenjaminBossan commented 4 months ago

Thanks for reporting this error and providing a reproducer. I identified the issue and created a PR to fix it. In the meantime, you can use a regular `NeuralNetClassifier` with two classes, which works the same as `NeuralNetBinaryClassifier` for most practical purposes and should have no issues with `torch.compile`.