Benjamin-eecs opened this issue 1 year ago
Thanks for the thorough writeup @Benjamin-eecs! To check a couple of things: is this a current bottleneck for your examples? We had been under the assumption that training would be much more expensive than initialization, but that may not be true (or it may be fair that we're losing out on performance by not doing this).
We're also currently looking at different ways that JAX libraries build neural nets and this is a great axis I hadn't thought of before. It looks like you might be using Flax or Haiku and I was wondering if you had tried this with Equinox at all?
cc @zou3519 This seems to be the same thing that the federated learning people were asking for. I forget if we got a clear answer for them.
Hi there @samdow, thanks for your quick and detailed feedback.
> is this a current bottleneck for your examples?

I think it can be called a bottleneck in some way: we can definitely initialize the ensemble of models and optimizers using a for loop. But our TorchOpt example mainly wants to show that we can support `functorch.vmap` for both initialization and training of an ensemble of models. Also, in our specific usage, where we want to repeat the same training process with different seeds or hyperparameters using `functorch.vmap`, we think it would be better if the user could write the `init_fn` in a functional way, as a function of a list of seeds or a list of hyperparameters. But for now, we can only initialize a set of models with the same weights.
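To make this concrete, here is a minimal sketch of the pattern we have in mind, using plain Python's `random` module as a stand-in for real parameter initialization (the `init_fn` name and weight shapes here are illustrative, not TorchOpt's actual API):

```python
import random

def init_fn(seed, hidden_dim=4):
    """Hypothetical init: returns a tiny weight vector drawn from `seed`.

    Stands in for initializing one ensemble member's parameters.
    """
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(hidden_dim)]

# Desired usage: map init_fn over a list of seeds so each ensemble
# member gets a different initialization (what a batched vmap call
# with randomness='different' would provide in one shot).
seeds = [0, 1, 2]
ensemble = [init_fn(s) for s in seeds]

# Different seeds give different weights ...
assert ensemble[0] != ensemble[1]
# ... while reusing one seed reproduces identical weights, which is
# the limitation we currently hit (analogous to randomness='same').
assert init_fn(0) == init_fn(0)
```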
> It looks like you might be using Flax or Haiku and I was wondering if you had tried this with Equinox at all?
I am not sure I fully understood. The JAX code snippet I showed in the writeup was just meant to show that our TorchOpt example turns the functorch example into JAX style, with an extra optimizer such as Adam in addition to SGD.
> Also, in our specific usage, where we want to repeat the same training process with different seeds or hyperparameters using `functorch.vmap`, we think it would be better if the user could write the `init_fn` in a functional way, as a function of a list of seeds or a list of hyperparameters.
To be clear, if this is the end goal, it will probably always be easier to write this as a for loop. Most of the hyperparameters are scalar values, and right now we can't vmap over lists or tensors of scalar values (1D tensors that we vmap over are treated as scalar tensors instead of scalars).
As an example, if we had an ensemble of models like the ones in the TorchOpt PR but where the hidden dimension was being changed:
```python
class MLP(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        ...
        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
        ...
```
we would never be able to vmap over different values of `hidden_dim`.
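To spell out the shape obstacle, here is a small stand-alone sketch (plain Python lists standing in for weight tensors; the dimensions are illustrative): vmap-style ensembling needs every member's parameters stacked into one batched tensor, which requires identical shapes, so varying `hidden_dim` rules it out while a for loop still works fine.

```python
def make_fc1_weight(hidden_dim, in_dim=2):
    # Stand-in for nn.Linear(in_dim, hidden_dim).weight: hidden_dim x in_dim.
    return [[0.0] * in_dim for _ in range(hidden_dim)]

# A for loop handles heterogeneous hyperparameters without trouble.
weights = [make_fc1_weight(h) for h in (16, 32)]

# But vmap-style ensembling needs a single stacked tensor with a leading
# batch dimension, which requires every member to have the same shape.
shapes = {(len(w), len(w[0])) for w in weights}
stackable = len(shapes) == 1
assert not stackable  # a 16x2 and a 32x2 matrix cannot be stacked
```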
> the JAX code snippet I showed in the writeup was just meant to show that our TorchOpt example ...
I see! Thanks for that clarification. I saw `classifier_fns.init` in the code snippet and assumed 😄
## Motivation
We recently used TorchOpt as a functional optimizer API, as mentioned in the functorch parallel training example, to achieve batchable optimization when training small neural networks on one GPU with `functorch.vmap`. With TorchOpt, we can mimic the JAX implementation and use `vmap` on the init function (instead of `combine_state_for_ensemble`).

However, any other `randomness` setting in `functorch.vmap(init_fn)` threw a bug (i.e. if `randomness='different'`). `functorch.vmap(init_fn, randomness='same')` gives identical inits for each net in the ensemble, which is not desirable if we want to train ensembles averaged across random seeds. Therefore, `functorch.vmap(init_fn)` supporting different randomness settings is a needed feature for this kind of usage.

cc @waterhorse1 @JieRen98 @XuehaiPan
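The contrast between the two randomness settings can be sketched without functorch at all. Below is a toy ensemble init using the stdlib `random` module (the function names are hypothetical; this only illustrates the two semantics, not functorch's implementation):

```python
import random

def init_member():
    # Stand-in for one model's random parameter initialization.
    return [random.gauss(0.0, 1.0) for _ in range(3)]

def init_ensemble(n, same_randomness):
    members = []
    for _ in range(n):
        if same_randomness:
            # Analogue of randomness='same': every member draws from the
            # same RNG state, so all inits come out identical.
            random.seed(0)
        members.append(init_member())
    return members

same = init_ensemble(3, same_randomness=True)

random.seed(42)  # fixed seed so the 'different' case is reproducible
diff = init_ensemble(3, same_randomness=False)

assert same[0] == same[1] == same[2]  # identical inits (current behavior)
assert diff[0] != diff[1]             # what randomness='different' should give
```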
## Solution
https://github.com/metaopt/TorchOpt/pull/32 can be run with `functorch.vmap(init_fn, randomness='different')`.

## Resources
- TorchOpt + functorch implementation
- TorchOpt + functorch implementation colab
- functorch parallel training example
- JAX + FLAX parallel training example
## Checklist
- [x] I have checked that there is no similar issue in the repo. There is one, #909, to improve `combine_state_for_ensemble` for initialization of an ensemble of models, and a related issue #782 asking for an implementation of this usage, but my request is more about giving a specific usage that requires this feature.