yandex-research / rtdl

Research on Tabular Deep Learning: Papers & Packages
Apache License 2.0
888 stars 98 forks

Fix MLP.make_baseline() return type #40

Closed: jpgard closed this pull request 1 year ago

jpgard commented 1 year ago

Return an object of type cls, not MLP, from MLP.make_baseline(). Otherwise, instances of subclasses of MLP that are constructed with the .make_baseline() method always have type MLP instead of the type of the subclass.
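A minimal sketch of the difference, assuming a simplified stand-in class. The real rtdl MLP is a torch.nn.Module whose constructor and make_baseline() take different arguments; MLP, make_baseline_hardcoded, and SklearnMLP below are hypothetical names used only for illustration. A classmethod that hardcodes the class name ignores the subclass it was called on, while one that calls cls(...) preserves it.

```python
class MLP:
    """Hypothetical, heavily simplified stand-in for rtdl's MLP.

    Only the class-construction logic is modelled here; the real class is a
    torch.nn.Module with a different constructor signature.
    """

    def __init__(self, d_in: int, d_out: int) -> None:
        self.d_in = d_in
        self.d_out = d_out

    @classmethod
    def make_baseline_hardcoded(cls, d_in: int, d_out: int) -> "MLP":
        # Behaviour described in the PR: the class name is hardcoded,
        # so calling this on a subclass still returns a plain MLP.
        return MLP(d_in, d_out)

    @classmethod
    def make_baseline(cls, d_in: int, d_out: int) -> "MLP":
        # Proposed fix: instantiate whichever class the method was called on.
        return cls(d_in, d_out)


class SklearnMLP(MLP):
    """Hypothetical subclass adding a scikit-learn-style method."""

    def predict_proba(self, X):
        raise NotImplementedError  # placeholder; real logic omitted


old = SklearnMLP.make_baseline_hardcoded(d_in=8, d_out=1)
new = SklearnMLP.make_baseline(d_in=8, d_out=1)
print(type(old).__name__)  # MLP
print(type(new).__name__)  # SklearnMLP
print(hasattr(old, "predict_proba"))  # False: the subclass's methods are lost
```

Calling cls(...) is the usual way for a classmethod alternative constructor to support subclassing; the hardcoded version silently drops any methods the subclass adds.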

jpgard commented 1 year ago

Note: this PR also makes the same change for FTTransformer.

Yura52 commented 1 year ago

Thank you for the proposal! In fact, this limitation was introduced intentionally, because I was not sure whether it was a good idea to allow using make_baseline with inherited models. I wonder whether you ran into this limitation in practice or just noticed it while browsing the codebase?

jpgard commented 1 year ago

Yes, this is something I encountered in practice. In our research on benchmarking tabular data algorithms, we commonly need a scikit-learn-style interface (e.g. methods such as fit, predict_proba, etc.) for all models, which is why I needed to subclass these.

What made it particularly confusing is that the ResNet class does not have this limitation, only MLP and FTTransformer do, so the behavior seemed inconsistent. Of course, if this is an intentional design choice, feel free to ignore this :), but I would be curious to know the reasoning.

Yura52 commented 1 year ago

In your particular case, I would create a separate class such as Trainer that implements the scikit-learn API and takes an instance of torch.nn.Module as a constructor argument (in the spirit of skorch). As for ResNet, this is just a "bug" that was overlooked; thanks for noticing :)
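A minimal sketch of the composition approach suggested above, assuming a dependency-free setting: LinearModel is a hypothetical stand-in for a torch.nn.Module, and Trainer, fit, and predict are illustrative names, not part of rtdl or skorch. The wrapper holds the model instead of inheriting from it, so it works with any model instance, including ones returned by make_baseline(), whatever their class.

```python
from typing import List, Sequence


class LinearModel:
    """Hypothetical stand-in for a torch.nn.Module: y = w * x + b."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


class Trainer:
    """Hypothetical scikit-learn-style wrapper that holds a model
    (composition, in the spirit of skorch) rather than subclassing it."""

    def __init__(self, module: LinearModel, lr: float = 0.01, n_epochs: int = 500) -> None:
        self.module = module
        self.lr = lr
        self.n_epochs = n_epochs

    def fit(self, X: Sequence[float], y: Sequence[float]) -> "Trainer":
        n = len(X)
        for _ in range(self.n_epochs):
            # Full-batch gradient descent on mean squared error;
            # both gradients are computed before either parameter is updated.
            grad_w = sum(2 * (self.module(x) - t) * x for x, t in zip(X, y)) / n
            grad_b = sum(2 * (self.module(x) - t) for x, t in zip(X, y)) / n
            self.module.w -= self.lr * grad_w
            self.module.b -= self.lr * grad_b
        return self  # scikit-learn convention: fit returns self

    def predict(self, X: Sequence[float]) -> List[float]:
        return [self.module(x) for x in X]


trainer = Trainer(LinearModel(), lr=0.05, n_epochs=2000)
trainer.fit([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])  # data follow y = 2x + 1
print(trainer.predict([4.0]))  # approximately [9.0]
```

Because the training logic lives outside the model class, no model needs to be subclassed, so the return type of make_baseline() does not matter for this use case.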

Yura52 commented 1 year ago

Closing the PR then?

jpgard commented 1 year ago

I suppose so, thanks for the clarification :)