dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org
BSD 3-Clause "New" or "Revised" License

ParallelPostFit excessive scheduler memory use and CancelledError #842

Open rikturr opened 3 years ago

rikturr commented 3 years ago

Notebook with MCVE and all notes: https://nbviewer.jupyter.org/gist/rikturr/43336377678018d01d4f21f19dd7ef11

When using ParallelPostFit to train on pandas/numpy objects and then predict on dask objects, I noticed that scheduler memory use runs extremely high. I would often get a CancelledError and the scheduler dying when calling .predict() with fairly small data sizes (refer to the notebook for full code with outputs):

X_train, X_test, y_train, y_test = ...
rf = ParallelPostFit(
    RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
)
_ = rf.fit(X_train, y_train)

preds = rf.predict(X_test)
_ = preds.compute()  # failure happens on this line after ~40 minutes

The scheduler memory balloons to npartitions of X_test * size of rf, which can get into multiple GBs very fast. I noticed that each time an operation was called on preds, this memory exchange would happen again. I realize this is because ParallelPostFit uses map_partitions behind the scenes but does not broadcast the model object. This causes Dask to send the object through the scheduler to each worker every time you do something with preds (unless, of course, you persisted it).
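A back-of-the-envelope sketch of that npartitions * model-size arithmetic. The `FakeModel` class here is a hypothetical stand-in for the fitted RandomForestClassifier, sized to roughly 8 MB; the point is only that, without broadcasting, each partition's task embeds its own pickled copy of the model:

```python
import pickle
import numpy as np

class FakeModel:
    """Hypothetical stand-in for a fitted model: ~8 MB of parameters."""
    def __init__(self):
        self.trees = np.random.rand(1_000_000)  # 1M float64 values ~ 8 MB

    def predict(self, df):
        return np.zeros(len(df))

model = FakeModel()
model_bytes = len(pickle.dumps(model))  # serialized size sent per task
npartitions = 100                       # e.g. X_test has 100 partitions

# Without broadcasting, every task carries the pickled model, so the
# scheduler moves roughly npartitions * model_bytes of graph payload
# each time an operation touches preds.
graph_payload = npartitions * model_bytes
print(f"model: {model_bytes / 1e6:.1f} MB, "
      f"graph payload: {graph_payload / 1e6:.0f} MB")
```

With a real forest of 100 trees the model is typically far larger than 8 MB, which is how the transfer reaches multiple GBs.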

A workaround is to broadcast the model object yourself and then call map_partitions directly, instead of using the ParallelPostFit wrapper:

rf = RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
_ = rf.fit(X_train, y_train)

# Send the fitted model to every worker once; rf_fut is a small Future.
rf_fut = client.scatter(rf, broadcast=True)

def dask_predict(df, model):
    # `model` resolves to the scattered estimator on the worker
    return model.predict(df)

preds = X_test.map_partitions(
    dask_predict,
    model=rf_fut,
    meta=np.array([1]),
)
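For reference, a minimal, self-contained sketch of what `client.scatter(..., broadcast=True)` buys you, using a plain NumPy array as a stand-in for the model (the names `big` and `use` are illustrative, not from the issue). The graph carries only the tiny Future key; the data itself is already resident on each worker:

```python
import numpy as np
from distributed import Client

# In-process cluster, just for demonstration.
client = Client(n_workers=2, threads_per_worker=1, processes=False)

big = np.ones(1_000_000)  # stand-in for the fitted model (~8 MB)

# One copy is placed on each worker up front; big_fut is a small Future.
big_fut = client.scatter(big, broadcast=True)

def use(chunk_id, arr):
    # `arr` arrives as the real array; the task itself only carried the key.
    return float(arr.sum()) + chunk_id

futures = [client.submit(use, i, big_fut) for i in range(4)]
results = client.gather(futures)
print(results)  # [1000000.0, 1000001.0, 1000002.0, 1000003.0]

client.close()
```

Passing `big` directly to `client.submit` instead of `big_fut` would serialize it into all four tasks, which is the same pathology ParallelPostFit hits at scale.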

I plan to follow up with a PR to fix this in the ParallelPostFit class.

rikturr commented 3 years ago

Just confirmed this happens with a LocalCluster too; it is easier to reproduce that way.