amine759 opened this pull request 4 months ago
Thanks for proposing the PR. A few points:

`CrossEntropyLoss` should already work with `NeuralNetClassifier`, shouldn't it?

@BenjaminBossan Hi, thanks for your reply.
As far as I understood, we seem to be restricted to using only the `NLLLoss` loss function with `skorch.NeuralNetClassifier`. When attempting to use any other loss function, like `CrossEntropyLoss`, we encounter an error indicating that the `criterion` parameter only accepts `NLLLoss`.

For the optimizer, by contrast, we can use any other than the default `SGD`. I assumed this is because `get_loss` does not handle an instance of `torch.nn.CrossEntropyLoss`. Since skorch doesn't use type annotations, I see now that this wasn't necessary in the first place. I could have just changed `get_loss` and that's it, right?
> When attempting to use any other loss function, like `CrossEntropyLoss`, we encounter an error indicating that the `criterion` parameter only accepts `NLLLoss`.
Could you please provide an example to reproduce this error? CE should absolutely work in skorch as is, e.g.:
```python
import numpy as np
from sklearn.datasets import make_classification
from torch import nn

from skorch import NeuralNetClassifier

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)


class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.output(X)
        return X


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    criterion=nn.CrossEntropyLoss(),
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)
y_proba = net.predict_proba(X)
```
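For context on why the module above can return raw scores: `CrossEntropyLoss` takes unnormalized logits, and it equals `NLLLoss` applied to the log of softmax probabilities. The pure-Python sketch below checks that identity for a single sample without torch. The helper names are hypothetical, and it assumes (from our reading of skorch, not something confirmed in this thread) that `NeuralNetClassifier.get_loss` only special-cases `NLLLoss` by taking the log of the module's probability output, passing other criteria through unchanged.

```python
import math

# Hypothetical helpers for illustration only; torch's implementations are
# vectorized and batched, but compute the same per-sample quantities.

def logsumexp(logits):
    # Subtract the largest logit before exponentiating to avoid overflow
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))

def softmax(logits):
    # Turn raw scores into probabilities that sum to 1
    lse = logsumexp(logits)
    return [math.exp(z - lse) for z in logits]

def nll_loss(log_probs, target):
    # Negative log-likelihood: minus the log-probability of the true class
    return -log_probs[target]

def cross_entropy(logits, target):
    # Cross-entropy on raw logits: -z[target] + log(sum_j exp(z[j]))
    return -logits[target] + logsumexp(logits)

# Raw scores for one sample with two classes, like the output layer of MyModule
logits = [2.0, -1.0]
target = 0

# Path 1: CrossEntropyLoss consumes the logits directly
ce = cross_entropy(logits, target)

# Path 2 (assumed default NeuralNetClassifier setup): the module outputs
# probabilities, get_loss takes their log, then NLLLoss is applied
probs = softmax(logits)
nll = nll_loss([math.log(p) for p in probs], target)

print(f"cross-entropy on logits:   {ce:.6f}")
print(f"NLL on log(probabilities): {nll:.6f}")
```

Both paths give the same number. Pairing a softmax output with `CrossEntropyLoss` would normalize twice, while pairing raw scores with a log-then-`NLLLoss` step would take the log of values that are not probabilities.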
This got a bit stale :)
@amine759 is this something you're still working on?
I had to use `CrossEntropyLoss` for my use case, so I thought of creating this PR, since `NLLLoss` is the only supported loss function :).