skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

FIX NeuralNetBinaryClassifier with torch.compile #1058

Closed. BenjaminBossan closed this 6 months ago

BenjaminBossan commented 6 months ago

Fixes #1057

NeuralNetBinaryClassifier did not work with torch.compile because the non-linearity was not inferred correctly. The inference depends on the instance type of the criterion. With torch.compile, however, the criterion is wrapped, so the isinstance check against the original criterion class fails. Now, we unwrap the criterion before checking the instance type.
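
To illustrate the idea, here is a minimal sketch in plain Python rather than the actual skorch code. `BCEWithLogitsLoss` and `CompiledWrapper` are stand-ins for the real PyTorch classes, and `unwrap` and `infer_nonlinearity` are hypothetical helper names. The sketch assumes the compiled wrapper keeps the original module in an `_orig_mod` attribute and that a logits-based BCE criterion maps to a sigmoid.

```python
class BCEWithLogitsLoss:
    """Stand-in for torch.nn.BCEWithLogitsLoss (assumption: no torch needed here)."""


class CompiledWrapper:
    """Stand-in for the wrapper returned by torch.compile.

    Assumption: the wrapper stores the original module as `_orig_mod`
    and is NOT a subclass of the wrapped module's class.
    """

    def __init__(self, module):
        self._orig_mod = module


def unwrap(criterion):
    # Hypothetical helper: return the wrapped module if there is one,
    # otherwise return the criterion unchanged.
    return getattr(criterion, "_orig_mod", criterion)


def infer_nonlinearity(criterion):
    # Hypothetical helper: unwrap first, then check the instance type.
    criterion = unwrap(criterion)
    if isinstance(criterion, BCEWithLogitsLoss):
        return "sigmoid"
    return "identity"


plain = BCEWithLogitsLoss()
compiled = CompiledWrapper(plain)

# Without unwrapping, the type check misses the wrapped criterion.
naive_check = isinstance(compiled, BCEWithLogitsLoss)

# With unwrapping, both the plain and the wrapped criterion are recognized.
plain_result = infer_nonlinearity(plain)
compiled_result = infer_nonlinearity(compiled)
```

Using `getattr` with a default keeps the helper a no-op for criteria that were never compiled, so the same code path serves both cases.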