pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.39k stars · 102 forks

Performant way to initialize an ensemble of models #909

Open zou3519 opened 2 years ago

zou3519 commented 2 years ago

Right now, to initialize an ensemble of e.g. 350 models, we first create 350 models and then combine their states together with combine_state_for_ensemble. This leaves some performance on the table; the fastest thing we could do is initialize the combined state in one go.
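For context, a minimal sketch of the current two-step pattern (the `nn.Linear` model and sizes here are illustrative, not from the issue): each model is instantiated separately, and the parameters are then stacked along a new leading dimension, which is roughly the combined state that `combine_state_for_ensemble` produces.

```python
import torch
import torch.nn as nn

def init_ensemble_naive(make_model, n):
    # Step 1: build n full modules (the part that wastes work).
    models = [make_model() for _ in range(n)]
    # Step 2: stack each parameter across models -> shape (n, *param_shape),
    # approximating what combine_state_for_ensemble returns.
    return {
        name: torch.stack([dict(m.named_parameters())[name] for m in models])
        for name, _ in models[0].named_parameters()
    }

params = init_ensemble_naive(lambda: nn.Linear(4, 2), 3)
print(params["weight"].shape)  # torch.Size([3, 2, 4])
```

The per-model construction is what "initializing the combined state in one go" would skip.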

This might not be too difficult to do. Idea from discussion with @Chillee is:
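The actual proposal isn't quoted above, but a hedged sketch of what "initialize the combined state in one go" could look like for a linear layer: allocate batched `(n, *shape)` parameters directly and initialize them in place, matching `nn.Linear`'s default init (kaiming-uniform with `a=sqrt(5)`, which reduces to a bound of `1/sqrt(fan_in)`), without ever constructing `n` modules.

```python
import torch

def init_linear_ensemble(n, in_features, out_features):
    # Bound matching nn.Linear.reset_parameters():
    # kaiming_uniform_(a=sqrt(5)) on the weight gives 1/sqrt(fan_in),
    # and the bias uses the same 1/sqrt(fan_in) bound.
    bound = 1.0 / in_features ** 0.5
    weight = torch.empty(n, out_features, in_features).uniform_(-bound, bound)
    bias = torch.empty(n, out_features).uniform_(-bound, bound)
    return weight, bias

weight, bias = init_linear_ensemble(350, 4, 2)
print(weight.shape)  # torch.Size([350, 2, 4])
```

One allocation and one RNG pass per parameter, instead of 350 module constructions followed by a stack.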

msaroufim commented 2 years ago

@zou3519 I'm curious if you have any thoughts on how scheduling would work then. If all the models are being combined into a single output, then you're going to be waiting for the slowest model, and if you have a DAG of models then again you'll be bottlenecked by the slower ones.

Is there any way to associate a given model with some resources, so you could say something like

`combine_state_for_ensemble([m1, m2, m3], [0.2, 0.8, 3])`, meaning: combine into a single model where `m1` has 20% of a GPU and `m3` has 3 GPUs available?

zou3519 commented 2 years ago

In order to combine the models together for ensembling with vmap, each model must call exactly the same sequence of PyTorch operations (otherwise, vmap will not work). So there isn't a "slowest model" in that case: when run separately, the models are expected to take roughly the same amount of time.
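To illustrate why, here is a minimal sketch using `torch.vmap` (the `functorch.vmap` from this repo behaves the same way; the toy linear `fmodel` is illustrative): every ensemble member runs the exact same functional forward, just over a different slice of the stacked parameters, so there is no per-member code path that could be slower than the others.

```python
import torch

# One functional forward shared by all ensemble members.
def fmodel(weight, bias, x):
    return x @ weight.T + bias

n, d_in, d_out = 3, 4, 2
weights = torch.randn(n, d_out, d_in)  # stacked per-model weights
biases = torch.randn(n, d_out)         # stacked per-model biases
x = torch.randn(8, d_in)

# Map over the leading model dim of the params; share x across members.
out = torch.vmap(fmodel, in_dims=(0, 0, None))(weights, biases, x)
print(out.shape)  # torch.Size([3, 8, 2])
```

Because vmap traces a single operation sequence, divergent control flow between members would make the batching impossible in the first place.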