Open zou3519 opened 2 years ago
@zou3519 I'm curious if you have any thoughts on how scheduling would work then. If all the models are being combined into a single output, then you're going to be waiting for the slowest model, and if you have a DAG of models, then again you'll be bottlenecked by the slower ones.
Is there any way to associate a given model with some resources, so you could say something like
combine_state_for_ensemble([m1, m2, m3], [0.2, 0.8, 3])
which would mean: combine into a single model where m1 has 20% of a GPU and m3 has 3 GPUs available?
In order to combine the models together for ensembling with vmap, each model must call exactly the same sequence of PyTorch operations (otherwise, vmap will not work). So in that case there isn't a "slowest model"; when run separately, the models are expected to take roughly the same amount of time.
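For context, this is roughly what ensembling with combine_state_for_ensemble and vmap looks like. The toy model below is illustrative; the key point is that every ensemble member runs the identical sequence of ops:

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

# Hypothetical tiny model; every ensemble member must execute the
# exact same sequence of PyTorch ops for vmap to work.
def make_model():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

models = [make_model() for _ in range(5)]

# Stack each parameter/buffer across models along a new leading dim.
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(16, 4)
# in_dims=(0, 0, None): map over the stacked params and buffers,
# share the same input minibatch across all ensemble members.
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # one output per model: (num_models, batch, out_features)
```

Because vmap traces a single computation through the stacked state, there is no per-model scheduling at all; the ensemble runs as one batched computation.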
Right now, to initialize an ensemble of e.g. 350 models, we first create 350 models and then combine their states together with combine_state_for_ensemble. This leaves some performance on the table; the fastest thing we could do is initialize the combined state in one go. This might not be too difficult to do. Idea from discussion with @Chillee is:
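As a rough sketch of what "initialize the combined state in one go" could mean (this is an assumption for illustration, not necessarily the idea referenced above): allocate the stacked [num_models, ...] tensors directly and run the usual per-model init on each slice, so no intermediate modules are ever constructed or copied:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: build the stacked ensemble state directly,
# instead of creating num_models modules and stacking their weights.
num_models, in_f, out_f = 350, 4, 2

# One template module defines the parameter shapes; no other modules exist.
template = nn.Linear(in_f, out_f)

stacked = {}
for name, p in template.named_parameters():
    t = torch.empty(num_models, *p.shape)
    for i in range(num_models):
        # Initialize each slice in place. A real implementation would
        # reuse the module's own reset_parameters logic (or a batched
        # init) so the statistics match separately-initialized models;
        # the rules here are simplified for illustration.
        if t[i].dim() > 1:
            nn.init.kaiming_uniform_(t[i], a=5 ** 0.5)
        else:
            nn.init.zeros_(t[i])
    stacked[name] = t

print({k: tuple(v.shape) for k, v in stacked.items()})
```

The saving comes from skipping the 350 module constructions and the copy into stacked storage; the per-slice init could further be replaced by a single batched init over the leading dimension.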