dask / distributed

A distributed task scheduler for Dask
https://distributed.dask.org
BSD 3-Clause "New" or "Revised" License

Excessive memory use in fold-style reductions #7552

Open gjoseph92 opened 1 year ago

gjoseph92 commented 1 year ago

This graph is not parallel. It's an incremental, serial reduction. Each reducer requires the previous reducer to finish before it can run. I've set up the tasks so that reducers are significantly slower than data producers.

Therefore, there's no need to load all the inputs into memory up front. It's going to be a long time until the final input task can be used. If we load it right away, it'll just take up memory.
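To make "fold-style" concrete, here's a sketch (illustration only, not the reproducer below) contrasting a serial fold with a tree reduction over the same inputs:

```python
# Illustration: fold-style (serial) vs. tree-style (parallel) reduction.
import dask


@dask.delayed(pure=False)
def load():
    return 1


@dask.delayed
def combine(a, b):
    return a + b


roots = [load() for _ in range(8)]

# Fold: each combine depends on the previous result, so reducers run one at a time.
folded = roots[0]
for r in roots[1:]:
    folded = combine(folded, r)

# Tree: pairs are combined independently, so each level's reducers can run in parallel.
layer = roots
while len(layer) > 1:
    pairs = [combine(a, b) for a, b in zip(layer[::2], layer[1::2])]
    leftover = [layer[-1]] if len(layer) % 2 else []
    layer = pairs + leftover
treed = layer[0]
```

In the fold, only one reducer is runnable at any moment, so the only parallelism the scheduler can exploit is running the root tasks, which is exactly what fills memory.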

[Screenshot: Screen Shot 2023-02-15 at 7 49 44 PM]

As you can see, even though the load tasks were queued, far more data was loaded into memory than we can process at once.

[Screenshot: Screen Shot 2023-02-15 at 7 49 50 PM]

With larger data sizes, or if there was some other computation going on at the same time, this probably could have killed the cluster.
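(Not from the original report, just a convenient way to reproduce the measurement without the dashboard: `Client.run` executes a function on every worker, and `psutil` is already a dependency of distributed. `client` is assumed to be the `distributed.Client` from the reproducer below.)

```python
import psutil


def worker_rss():
    # Resident set size of this worker process, in bytes.
    return psutil.Process().memory_info().rss


# Returns {worker_address: rss_bytes}, one entry per worker.
print(client.run(worker_rss))
```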

This was motivated by playing around with dask-ml and incremental training. AFAIU the point of incremental training is to be able to train on a larger-than-memory dataset by processing it chunk by chunk. But it seems this scheduling behavior might defeat the purpose, since all the data will end up loaded into distributed memory anyway (as long as training is slower than data loading, which is quite possible with a big ML model). Hopefully spilling will save you in the real world, but it still doesn't seem like great behavior.

No ideas yet on how to address this; it's just interesting to think about in the context of other scheduling questions like https://github.com/dask/distributed/pull/7531

Minimal reproducer:

```python
import time

import dask
import distributed
from dask.utils import parse_bytes


@dask.delayed(pure=False)
def load():
    return "x" * parse_bytes("50MB")


@dask.delayed()
def fit(prev, data):
    time.sleep(1)
    return prev + 1


roots = [load() for _ in range(50)]
prev = fit(0, roots[0])
for r in roots[1:]:
    prev = fit(prev, r)

if __name__ == "__main__":
    with distributed.Client(
        n_workers=4, threads_per_worker=1, memory_limit="1 GiB"
    ) as client:
        prev.compute()
```

Dask-ml example:

```python
# model.py
import time

from torch import nn


class MyModel(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()
        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)

    def forward(self, X, **kwargs):
        time.sleep(5)
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.output(X)
        return X
```

```python
import distributed
import numpy as np
from dask_ml.datasets import make_classification
from dask_ml.wrappers import Incremental
from torch import nn
from skorch import NeuralNetClassifier

from model import MyModel

X, y = make_classification(
    100_000, 20, n_informative=10, random_state=0, chunks=(10000, 20)
)
X = X.astype(np.float32)
y = y.astype(np.int64)

niceties = {
    "callbacks": False,
    "warm_start": False,
    "train_split": None,
    "max_epochs": 1,
}
net = NeuralNetClassifier(
    MyModel,
    criterion=nn.CrossEntropyLoss(),
    lr=0.1,
    **niceties,
)
model = Incremental(net, scoring="accuracy")

if __name__ == "__main__":
    with distributed.Client() as client:
        model.fit(X, y)
```
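For context on the dask-ml case: the dependency structure Incremental produces is conceptually the same serial fold, one `partial_fit` per block, each depending on the model state from the previous block. A simplified sketch (not dask-ml's actual implementation):

```python
# Simplified sketch of chunk-wise incremental fitting; dask-ml's real code
# differs, but the graph shape is the same serial chain over blocks.
import dask


@dask.delayed
def partial_fit_block(model, X_block, y_block):
    # Each call needs the model returned by the previous block, so the
    # fitting tasks form a strictly serial chain, just like `fit` above.
    model.partial_fit(X_block, y_block)
    return model


def incremental_fit(model, X, y):
    # X and y are dask arrays chunked along axis 0; their blocks play the
    # role of the `load` tasks above.
    fitted = dask.delayed(model)
    for Xb, yb in zip(X.to_delayed().ravel(), y.to_delayed().ravel()):
        fitted = partial_fit_block(fitted, Xb, yb)
    return fitted
```

If each block is cheap to produce but `partial_fit` is slow, nothing stops the scheduler from materializing every block long before the chain can consume it, which is the behavior shown above.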
TomAugspurger commented 1 year ago

xref https://github.com/dask/dask-ml/issues/765, which describes this too.

gjoseph92 commented 1 year ago

Yup, same situation—thanks for the link @TomAugspurger.

Just to also note that, as expected, the ordering for the graph is good:

[Graph visualization: mydask.png]
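(For anyone who wants to reproduce that check: dask can color the graph by its `dask.order` priority. This assumes graphviz and matplotlib are installed and that `prev` is the final delayed object from the reproducer above.)

```python
# Color each task by its dask.order priority; lower numbers are scheduled first.
# By default this writes the rendered graph to mydask.png.
prev.visualize(color="order", cmap="viridis")
```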