Open kartik4949 opened 10 months ago
As a first step, we can expose the batch size to the user.
Then we can optimize further by estimating the ideal batch size automatically based on the available memory.
We also need to account for parallel processing by multiple workers (are they going to process the same chunk or different ones?).
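A minimal sketch of that estimation step, assuming `psutil` is acceptable for memory introspection; `estimate_batch_size`, `sample`, `max_memory_fraction` and the clamping bounds are all hypothetical names, not part of the existing API:

```python
import sys

import psutil  # assumption: psutil (or similar) is available for memory introspection


def estimate_batch_size(sample, max_memory_fraction=0.1, floor=1, ceiling=10_000):
    """Rough heuristic: size one sample, then fit a batch into a fraction of free RAM."""
    available = psutil.virtual_memory().available
    # sys.getsizeof underestimates nested objects; good enough for an order-of-magnitude guess.
    sample_bytes = max(sys.getsizeof(sample), 1)
    batch_size = int(available * max_memory_fraction / sample_bytes)
    return max(floor, min(batch_size, ceiling))
```

A user-facing batch size from the first step could simply override whatever this heuristic returns.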
When the model predicts for, say, 1 million data points, the predict method stores the outputs for all 1 million data points in a single list, outputs, which will OOM once it exceeds available memory (refer: superduper/components/model.py: predict method).
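One way to avoid that, sketched below with illustrative names (predict_stream, batches and the db.insert call are stand-ins, not the current signatures in components/model.py): turn prediction into a generator so each batch of outputs can be persisted and discarded before the next one is computed.

```python
def predict_stream(model, batches):
    """Yield predictions batch by batch instead of accumulating them in one list,
    so only a single batch of outputs is resident in memory at any time."""
    for batch in batches:
        yield model.predict(batch)


# Usage sketch: consume and flush one batch at a time.
# for outputs in predict_stream(model, batches):
#     db.insert(outputs)  # hypothetical persistence call; outputs can be GC'd afterwards
```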
The same thing happens with model inputs: all inputs are loaded into memory before being passed to the model, e.g. packed into a DataLoader (refer: ext/torch/model.py: _predict method).
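On the torch side, a rough sketch assuming the datalayer can serve single rows on demand (LazyQueryDataset, fetch_row and n_rows are hypothetical names): let the Dataset fetch lazily so the DataLoader only materialises one batch at a time.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class LazyQueryDataset(Dataset):
    """Map-style dataset that fetches one row at a time instead of
    materialising the whole input set up front."""

    def __init__(self, fetch_row, n_rows):
        self.fetch_row = fetch_row  # hypothetical callable: idx -> raw row
        self.n_rows = n_rows

    def __len__(self):
        return self.n_rows

    def __getitem__(self, idx):
        return torch.as_tensor(self.fetch_row(idx))


# Only `batch_size` rows are held in memory per step; DataLoader workers could
# additionally prefetch upcoming batches in parallel.
loader = DataLoader(
    LazyQueryDataset(fetch_row=lambda i: [float(i)], n_rows=1_000_000),
    batch_size=256,
)
```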
We need to chunk the model inputs in the database, iterate over the chunks, and pass each chunk to the model for prediction, as sketched below.
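Putting it together, something along these lines, where db.select_ids, db.select_rows, db.insert_outputs and model.predict are placeholders for whatever the datalayer and model actually expose, not the real API:

```python
def chunked(ids, chunk_size):
    """Yield primary keys in fixed-size chunks so we never pull everything at once."""
    for i in range(0, len(ids), chunk_size):
        yield ids[i:i + chunk_size]


def predict_over_database(db, model, query, chunk_size=1_000):
    """Select a chunk of inputs, predict on it, write the outputs back,
    then move on to the next chunk; neither inputs nor outputs accumulate."""
    ids = db.select_ids(query)  # assumption: ids alone are small enough to hold in memory
    for chunk in chunked(ids, chunk_size):
        inputs = db.select_rows(query, ids=chunk)
        outputs = model.predict(inputs)
        db.insert_outputs(chunk, outputs)
```

Whether multiple workers take disjoint chunks (e.g. by partitioning the id list) or share one chunk would then be decided at this level.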