dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org
BSD 3-Clause "New" or "Revised" License

Using PyTorch models for training and hyperparameter tuning with Dask #836

Open jrbourbeau opened 3 years ago

jrbourbeau commented 3 years ago

I was speaking with a group today who are using Dask for I/O and preprocessing of a larger-than-memory dataset, and who then want to train a single PyTorch model on that large dataset. My initial thought was to look at using Dask-ML's Incremental meta-estimator; however, after reading the docs more closely, I realized that Incremental only supports models with a partial_fit method.

I'm wondering what the best practices are today for training a PyTorch model (or running a hyperparameter optimization) on a large Dask collection. Should users use skorch to wrap their PyTorch models in a scikit-learn-compatible API? Is there some other approach users should take?

cc'ing @stsievert as you may have thoughts on this topic

jrbourbeau commented 3 years ago

I just ran across https://ml.dask.org/pytorch, which makes me think that today's best practice is to use skorch to integrate PyTorch models with Incremental and HyperbandSearchCV. Though I'm still curious to hear if there are other thoughts on this topic, or if there's an appetite for more first-class PyTorch support.
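For concreteness, here's a minimal sketch of that pattern, assuming a hypothetical toy module and random data; the skorch wrapper supplies the partial_fit method that Incremental needs:

```python
import dask.array as da
import torch.nn as nn
from dask_ml.wrappers import Incremental
from skorch import NeuralNetClassifier

class Net(nn.Module):                   # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 2)
        )

    def forward(self, x):
        return self.layers(x)

# skorch gives the module a scikit-learn API, including partial_fit
net = NeuralNetClassifier(
    Net,
    criterion=nn.CrossEntropyLoss,  # module outputs raw logits
    max_epochs=1,
    lr=0.1,
    train_split=None,               # no internal holdout on each chunk
)

# Incremental calls partial_fit once per Dask chunk, in sequence
model = Incremental(net, scoring="accuracy")

X = da.random.random((10_000, 20), chunks=(1_000, 20)).astype("float32")
y = (da.random.random(10_000, chunks=1_000) > 0.5).astype("int64")

# classes is forwarded to partial_fit (skorch accepts and ignores it)
model.fit(X, y, classes=[0, 1])
```

HyperbandSearchCV can take the same skorch wrapper in place of Incremental, with a parameter distribution over, e.g., lr or module__* hyperparameters.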

stsievert commented 3 years ago

Wrapping PyTorch models with Skorch has been my first instinct in a couple of projects, and generally it works well. However, Skorch can have some performance overhead, as mentioned on its "Performance" doc page: https://skorch.readthedocs.io/en/stable/user/performance.html

I'd recommend following that page; it's decent, and it mentions a couple of use cases that burned me. Without mitigation, Skorch has poor performance in these cases (a sketch of both mitigations follows the list):

  1. When small data or small models are used. If the time spent in PyTorch code is small, Skorch's per-call overhead dominates; converting the dataset to tensors beforehand reduces it. In Dask parlance, this happens with small chunks.
  2. When its callbacks are enabled (the default). They can be disabled by setting callbacks="disable" in the most recent version.
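To make those mitigations concrete, here is a small sketch (hypothetical toy model and random data, not taken from the skorch docs) applying both:

```python
import numpy as np
import torch
import torch.nn as nn
from skorch import NeuralNetClassifier

class TinyNet(nn.Module):               # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(20, 2)

    def forward(self, x):
        return self.linear(x)

# Mitigation 1: convert the data to torch tensors once, up front, so
# skorch doesn't pay a numpy -> tensor conversion on every call.
X = torch.from_numpy(np.random.rand(1_000, 20).astype("float32"))
y = torch.from_numpy(np.random.randint(0, 2, 1_000).astype("int64"))

# Mitigation 2: disable the default callbacks (epoch scoring, printed
# logs, timers), whose bookkeeping dominates with small models/chunks.
net = NeuralNetClassifier(
    TinyNet,
    criterion=nn.CrossEntropyLoss,
    max_epochs=1,
    callbacks="disable",
    train_split=None,
)
net.partial_fit(X, y)
```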

For my use case, I eventually wound up only using Skorch for initialization (at _embed.py#L88-L101) inside a hand-crafted scikit-learn-compatible estimator (e.g., I manually call skorch_nn.module_.forward inside my partial_fit function).
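A rough sketch of that pattern (hypothetical names; not the actual _embed.py code) looks something like this:

```python
import torch
import torch.nn.functional as F
from sklearn.base import BaseEstimator

class HandRolledNet(BaseEstimator):
    """Scikit-learn-style estimator that uses skorch only for setup."""

    def __init__(self, skorch_nn, lr=0.1):
        self.skorch_nn = skorch_nn
        self.lr = lr

    def partial_fit(self, X, y, classes=None):
        if not hasattr(self, "opt_"):
            # skorch handles initialization and builds module_ ...
            self.skorch_nn.initialize()
            self.module_ = self.skorch_nn.module_
            self.opt_ = torch.optim.SGD(self.module_.parameters(), lr=self.lr)
        X = torch.as_tensor(X, dtype=torch.float32)
        y = torch.as_tensor(y, dtype=torch.int64)
        # ... but each training step calls module_.forward by hand,
        # bypassing skorch's per-call machinery entirely
        self.opt_.zero_grad()
        loss = F.cross_entropy(self.module_.forward(X), y)
        loss.backward()
        self.opt_.step()
        return self
```

Because it exposes partial_fit, a wrapper like this also drops straight into Incremental or HyperbandSearchCV.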