jrbourbeau opened 3 years ago
I just ran across https://ml.dask.org/pytorch, which makes me think the best practice today is to use skorch to integrate PyTorch models with `Incremental` and `HyperbandSearchCV`. Though I'm still curious to hear whether there are other thoughts on this topic, or if there's an appetite for more first-class PyTorch support.
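For concreteness, here is a minimal sketch of that integration, following the pattern described at https://ml.dask.org/pytorch. The module, data shapes, and hyperparameters below are illustrative assumptions, not code from the docs or this issue:

```python
# Sketch: wrap a PyTorch module with skorch so it gains the scikit-learn
# API (including partial_fit), then hand it to dask-ml's Incremental.
# SimpleMLP and all hyperparameters here are made up for illustration.
import dask.array as da
import torch.nn as nn
from skorch import NeuralNetRegressor
from dask_ml.wrappers import Incremental

class SimpleMLP(nn.Module):
    def __init__(self, n_features=20):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 32), nn.ReLU(), nn.Linear(32, 1)
        )

    def forward(self, X):
        return self.net(X)

# skorch gives the PyTorch module fit/partial_fit/predict/score methods
model = NeuralNetRegressor(
    module=SimpleMLP,
    lr=0.1,
    max_epochs=1,      # Incremental controls the passes over the data
    train_split=None,  # skip skorch's internal train/validation split
)

# Incremental calls partial_fit once per block of the Dask arrays
X = da.random.random((10_000, 20), chunks=(1_000, 20)).astype("float32")
y = da.random.random((10_000, 1), chunks=(1_000, 1)).astype("float32")
inc = Incremental(model, scoring="neg_mean_squared_error")
inc.fit(X, y)
```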
Wrapping PyTorch models with Skorch has been my first instinct in a couple of projects. Generally, this works well. However, Skorch can have some performance overhead, as mentioned on its "Performance" doc page: https://skorch.readthedocs.io/en/stable/user/performance.html

I'd recommend following that page; it's decent, and it mentions a couple of use cases that burned me. Without mitigation, Skorch has poor performance in those cases; one mitigation it mentions is passing `callbacks="disable"`, which is supported in the most recent version. For my use case, I eventually wound up only using Skorch for initialization (at _embed.py#L88-L101) inside a hand-crafted scikit-learn-compatible estimator (e.g., I manually call `skorch_nn.module_.forward` inside my `partial_fit` function).
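To make that last pattern concrete, here is a rough sketch of the kind of estimator described: Skorch handles initialization once, and the training step is done by hand inside `partial_fit`. Everything here (shapes, loss, optimizer, class name) is assumed for illustration and is not the actual code in _embed.py:

```python
# Sketch: use skorch only to initialize the module, then run the
# forward/backward pass manually inside partial_fit, bypassing skorch's
# callback and logging machinery. Loss/optimizer choices are assumptions.
import numpy as np
import torch
from sklearn.base import BaseEstimator
from skorch import NeuralNetRegressor

class ManualNet(BaseEstimator):
    def __init__(self, module, lr=0.01):
        self.module = module
        self.lr = lr

    def _initialize(self):
        # skorch instantiates the module and handles setup once
        self.skorch_nn_ = NeuralNetRegressor(module=self.module, lr=self.lr)
        self.skorch_nn_.initialize()
        self.optimizer_ = torch.optim.SGD(
            self.skorch_nn_.module_.parameters(), lr=self.lr
        )
        self.loss_ = torch.nn.MSELoss()

    def partial_fit(self, X, y):
        if not hasattr(self, "skorch_nn_"):
            self._initialize()
        X = torch.from_numpy(np.asarray(X, dtype="float32"))
        y = torch.from_numpy(np.asarray(y, dtype="float32"))
        # one manual training step per partial_fit call
        self.optimizer_.zero_grad()
        loss = self.loss_(self.skorch_nn_.module_.forward(X), y)
        loss.backward()
        self.optimizer_.step()
        return self
```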
I was speaking with a group today who are using Dask for I/O and preprocessing of a larger-than-memory dataset, and who then want to train a single PyTorch model on that large dataset. My initial thought was to look at using Dask-ML's `Incremental` meta-estimator; however, after reading the docs more closely I realized that `Incremental` only supports models with a `partial_fit` method.

I'm wondering: what are the best practices today for training a PyTorch model (or running a hyperparameter optimization) on a large Dask collection? Should users use skorch to wrap their PyTorch models in a scikit-learn-compatible API? Is there some other approach users should take?
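For reference, the hyperparameter-optimization half of the question would presumably use the same skorch wrapper with `HyperbandSearchCV`. A hedged sketch, reusing `model`, `X`, and `y` from the `Incremental` sketch earlier in this thread, with a made-up parameter space:

```python
# Sketch: HyperbandSearchCV also relies on partial_fit, so a skorch-wrapped
# model slots in directly. `model`, `X`, `y` come from the earlier sketch;
# the parameter space and max_iter are illustrative only.
from dask_ml.model_selection import HyperbandSearchCV

param_space = {"lr": [0.01, 0.05, 0.1]}
search = HyperbandSearchCV(model, param_space, max_iter=9)
search.fit(X, y)
print(search.best_params_)
```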
cc'ing @stsievert as you may have thoughts on this topic