dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org
BSD 3-Clause "New" or "Revised" License

Provide wrappers for popular ML libraries #696

Open stsievert opened 4 years ago

stsievert commented 4 years ago

It'd be convenient to support the use of Keras or PyTorch models in model selection. There are two issues:

  1. Keras/PyTorch models don't conform to the Scikit-learn API.
  2. Keras models are not pickle-able.

I'm imagining this interface:

from torchvision.models import resnet18
import torch.optim as optim
from dask_ml.wrappers import SkorchClassifier  # proposed wrapper, doesn't exist yet

pytorch_model = resnet18()
sklearn_model = SkorchClassifier(
    model=pytorch_model,
    model__alpha=1e-2,  # if resnet18 had a kwarg `alpha`
    optimizer=optim.SGD,
    optimizer__lr=0.1,
)
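
For context, the payoff would be dropping the wrapped model into Dask-ML's existing model selection. A minimal sketch, assuming the proposed wrapper above and Dask arrays X_train/y_train (HyperbandSearchCV is real and calls partial_fit under the hood):

from dask_ml.model_selection import HyperbandSearchCV

# sample over the wrapper's double-underscore parameters, sklearn-style
params = {"optimizer__lr": [0.1, 0.01, 0.001]}

search = HyperbandSearchCV(sklearn_model, params, max_iter=27)
search.fit(X_train, y_train)  # X_train, y_train: dask arrays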

Related issues/PRs: the same complaint was raised in dask/distributed: https://github.com/dask/distributed/issues/3873

stsievert commented 4 years ago

I think this is possible with wrappers like SciKeras and Skorch.

edit: these libraries are discussed below, and both are now mentioned in Dask-ML's documentation on wrappers.

adriangb commented 4 years ago

If you're looking for a way to make Keras models conform to the scikit-learn API, check out SciKeras (full disclosure: I'm the author)

TomAugspurger commented 4 years ago

Thanks for the link, Adrian. To minimize our maintenance burden, I'd set the goal that our model_selection estimators work with any model implementing the scikit-learn interface, and encourage the development and use of wrappers like skorch and SciKeras.

On top of that, we have the additional burden of these models needing to work well with distributed's serialization. To the extent possible, that functionality should be in the projects themselves (making Keras models picklable) or in distributed.

mrocklin commented 4 years ago

How hard is it to support the PyTorch/Keras fit/predict APIs? If it's as simple as writing a function like the following, then I would be in favor:

def fit(estimator, X, y=None):
    if hasattr(estimator, "fit"):
        return estimator.fit(X, y)
    elif hasattr(estimator, ...): #  pytorch-like
        return ...
    elif hasattr(estimator, ...): # keras-like
        return ...
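
One wrinkle: Keras models already expose a scikit-learn-like .fit(X, y), so the duck-typing mostly has to special-case raw PyTorch modules. A rough sketch of that dispatch (not an existing dask-ml function; _torch_fit is a hypothetical helper):

import torch

def fit(estimator, X, y=None):
    # illustrative duck-typed dispatch, not an existing dask-ml function
    if isinstance(estimator, torch.nn.Module):
        # raw PyTorch modules have no .fit and need an explicit training loop
        return _torch_fit(estimator, X, y)  # _torch_fit: hypothetical helper
    if hasattr(estimator, "fit"):
        # scikit-learn estimators, skorch/SciKeras wrappers, and raw Keras
        # models all accept .fit(X, y)
        return estimator.fit(X, y)
    raise TypeError(f"don't know how to fit {type(estimator)!r}")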

mrocklin commented 4 years ago

For serialization, I think we have a decent PyTorch serializer in distributed (early work from @stsievert, if I recall correctly). I don't think we have anything for Keras today.

mrocklin commented 4 years ago

Serialization is also maybe something we could ask the RAPIDS folks for help with, like @quasiben @jakirkham @pentschev. It's not RAPIDS, obviously, but these issues are often GPU-related and that team is familiar with them.

adriangb commented 4 years ago

I'm not 100% sure what the goal is here (I just came from the discussion in tensorflow/tensorflow#39609) but SciKeras adds serialization support to the Keras models it wraps. Ex:

from scikeras.wrappers import KerasRegressor

keras_model = ...  # some Keras model object; can be Sequential or Functional

wrapped_model = KerasRegressor(keras_model)  # a serializable, scikit-learn-compliant estimator
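
Concretely (a minimal sketch, assuming the wrapped_model above has been built), the wrapper survives a standard pickle round-trip:

import pickle

# the wrapper makes the underlying Keras model pickle-compatible
roundtripped = pickle.loads(pickle.dumps(wrapped_model))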

So I guess you could just tell your users to wrap their Keras models before using them with dask-ml?

quasiben commented 4 years ago

Serialization is definitely something RAPIDS cares about. SciKeras looks interesting -- @adriangb, do you know if it forces a host-to-device transfer? Does it support the __cuda_array_interface__? If so, I believe things are a lot easier for us.

cc @JohnZed maybe pytorch serialization is something cuML would also care about

mrocklin commented 4 years ago

Current PyTorch serialization is here: https://github.com/dask/distributed/blob/master/distributed/protocol/torch.py

It looks like it forces things to numpy though, and so may not be GPU-optimized.
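
To illustrate why routing through numpy isn't GPU-friendly (a sketch assuming a CUDA-enabled PyTorch build):

import torch

t = torch.ones(3, device="cuda")
# t.numpy() would raise here: numpy conversion requires host memory
host_copy = t.cpu().numpy()  # forces a device -> host transfer first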

Rather than scikeras, I'm still curious whether we can make things more torch/tf/keras-native cheaply.

stsievert commented 4 years ago

So I guess you could just tell your users to wrap their Keras models before using them with dask-ml?

That's the plan. To do that, models need to support serialization and implement partial_fit (see https://github.com/adriangb/scikeras/pull/17).
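
For context, Dask-ML's incremental estimators feed each block of a Dask array to the estimator's partial_fit. A minimal sketch with the (real) Incremental wrapper, assuming wrapped_model is a skorch/SciKeras regressor:

import dask.array as da
from dask_ml.wrappers import Incremental

X = da.random.random((10_000, 20), chunks=(1_000, 20))
y = da.random.random(10_000, chunks=1_000)

inc = Incremental(wrapped_model)  # any estimator with partial_fit
inc.fit(X, y)  # calls wrapped_model.partial_fit once per block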

pytorch serialization is something cuML would also care about

may not be GPU-optimized.

PyTorch has serialization support, even though they recently tried to remove it! https://github.com/pytorch/pytorch/issues/38597 Skorch wraps PyTorch, and it looks like it supports GPUs: skorch/net.py#L1608.
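
A quick check of the standard protocol (a sketch; torchvision assumed installed):

import pickle
from torchvision.models import resnet18

model = resnet18()
restored = pickle.loads(pickle.dumps(model))  # nn.Module pickles out of the box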

adriangb commented 4 years ago

do you know if it forces a host to device transfer ? Does it support the __cuda_array_interface__ ? If so, I believe things are a lot easier for us.

To be honest, I am not familiar with these terms. All SciKeras does is implement copy.deepcopy- and pickle-compatible serialization. It does not expose a __cuda_array_interface__ attribute, so I think the answer is no.
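
For anyone else unfamiliar, the protocol is just an attribute that GPU array libraries expose so device memory can be shared without a copy (an illustrative check):

import numpy as np

arr = np.ones(3)
hasattr(arr, "__cuda_array_interface__")  # False: NumPy arrays live in host memory
# GPU libraries such as CuPy expose this attribute for zero-copy interop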

models need to support serialization and implement partial_fit (see adriangb/scikeras#17).

Will take a look tonight!

jakirkham commented 4 years ago

How difficult would it be to implement pickling for Keras as well, like Matt did for PyTorch (https://github.com/pytorch/pytorch/pull/9184)? There's a lot of value gained by supporting standard Python protocols. That's not to say there may not be additional gains with Dask serialization; just that having the standard protocol working would make interop with various distributed computing libraries (including Dask) easier.

stsievert commented 4 years ago

How difficult would it be to implement pickling ... for Keras as well?

SciKeras has an implementation at scikeras/wrappers.py#L87. There's currently an open PR to merge this into TensorFlow/Keras master: https://github.com/tensorflow/tensorflow/pull/39609
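
The general shape of such a pickle hook (a sketch of the idea, not the actual SciKeras or TensorFlow code; it round-trips the model through Keras' native HDF5 format):

import os
import tempfile
import tensorflow as tf

def _keras_to_bytes(model):
    # serialize via Keras' own save format, then hand pickle the raw bytes
    fd, path = tempfile.mkstemp(suffix=".h5")
    os.close(fd)
    try:
        model.save(path)
        with open(path, "rb") as f:
            return f.read()
    finally:
        os.remove(path)

def _keras_from_bytes(blob):
    # rebuild the model from the saved bytes
    fd, path = tempfile.mkstemp(suffix=".h5")
    os.close(fd)
    try:
        with open(path, "wb") as f:
            f.write(blob)
        return tf.keras.models.load_model(path)
    finally:
        os.remove(path)

# give every Keras model a __reduce__ so pickle can round-trip it
tf.keras.Model.__reduce__ = lambda self: (_keras_from_bytes, (_keras_to_bytes(self),))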

Does that answer your question?

stsievert commented 3 years ago

Here are two more PyTorch wrappers:

  1. adadamp, which lets you train PyTorch models on Dask clusters and presents a scikit-learn interface.
  2. saturncloud/dask-pytorch-ddp, which allows use of a Dask cluster with PyTorch distributed code.