Closed jpgard closed 1 year ago
Note: also makes a similar change for FTTransformer.
Thank you for the proposal! In fact, this limitation was introduced intentionally, because I was not sure if it was a good thing to allow using `make_baseline` with inherited models. I wonder if you faced the limitation in practice, or is it just something that you noticed while browsing the codebase?
This is something I encountered in practice, yes. For our research use cases when benchmarking tabular data algorithms, it is common for us to need a scikit-learn-style interface (i.e. methods such as `fit`, `predict_proba`, etc.) for all models, which is why I needed to subclass these.
What made it particularly confusing is that the ResNet class does not have this limitation, only MLP and FTTransformer, so the behavior didn't seem consistent. Of course, if there is an intentional design choice behind this, feel free to ignore :), but I would be curious to know.
In your particular case I would create a separate class like `Trainer` that would implement the scikit-learn API and take an instance of `torch.nn.Module` as an argument in the constructor (i.e. in the spirit of skorch). As for ResNet, this is just a "bug" and was overlooked, thanks for noticing :)
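A minimal sketch of what such a wrapper could look like (the `Trainer` name and all hyperparameters here are illustrative, not part of any library's API):

```python
import numpy as np
import torch
import torch.nn as nn


class Trainer:
    """Hypothetical scikit-learn-style wrapper around an arbitrary
    torch.nn.Module, in the spirit of skorch. Names are illustrative."""

    def __init__(self, module: nn.Module, lr: float = 1e-3, n_epochs: int = 10):
        self.module = module
        self.lr = lr
        self.n_epochs = n_epochs

    def fit(self, X, y):
        X = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        y = torch.as_tensor(np.asarray(y), dtype=torch.long)
        optimizer = torch.optim.Adam(self.module.parameters(), lr=self.lr)
        loss_fn = nn.CrossEntropyLoss()
        self.module.train()
        for _ in range(self.n_epochs):
            optimizer.zero_grad()
            loss = loss_fn(self.module(X), y)
            loss.backward()
            optimizer.step()
        return self  # scikit-learn convention: fit returns self

    def predict_proba(self, X):
        self.module.eval()
        with torch.no_grad():
            logits = self.module(torch.as_tensor(np.asarray(X), dtype=torch.float32))
        return torch.softmax(logits, dim=1).numpy()
```

This keeps the model class untouched and puts the scikit-learn interface in a separate object, so any `nn.Module` (including ones built via `make_baseline`) can be wrapped without subclassing.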
Closing the PR then?
I suppose so, thanks for the clarification :)
Return an object of type `cls`, not `MLP`, in `MLP.make_baseline()`. Otherwise, child classes inheriting from `MLP` and constructed via the `.make_baseline()` method always have type `MLP` instead of the type of the child class.
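The fix can be illustrated with a stripped-down sketch (the constructor signature here is simplified for illustration and does not match the real `MLP`):

```python
class MLP:
    def __init__(self, d_layers):
        self.d_layers = d_layers

    @classmethod
    def make_baseline(cls, d_layers):
        # Using cls(...) instead of MLP(...) means the classmethod
        # constructs an instance of whatever class it was called on,
        # so subclasses get instances of the subclass.
        return cls(d_layers)


class SklearnStyleMLP(MLP):
    """Hypothetical subclass adding a scikit-learn-style interface."""

    def fit(self, X, y):
        return self


model = SklearnStyleMLP.make_baseline([64, 64])
assert type(model) is SklearnStyleMLP  # with `return MLP(...)` this would be MLP
```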