pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

[Feature Request] Support `different` randomness settings to train an ensemble of models with TorchOpt #996

Open Benjamin-eecs opened 1 year ago

Benjamin-eecs commented 1 year ago

Motivation

We recently used TorchOpt as a functional optimizer API, as mentioned in the functorch parallel training example, to batch the training of small neural networks on one GPU with functorch.vmap.

With TorchOpt, we can mimic the JAX implementation and use vmap on the init function.

JAX:

def init_fn(input_shape, seed):
    rng = jr.PRNGKey(seed)                                     # jr = jax.random
    dummy_input = jnp.ones((1, *input_shape))
    params = classifier_fns.init(rng, dummy_input)['params']   # do shape inference
    optimizer_def = optim.Adam(learning_rate=1e-3)
    optimizer = optimizer_def.create(params)
    return optimizer
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
model_states = parallel_init_fn((2,), seeds)

TorchOpt + functorch:

def init_fn(model_idx):
    _, weights = functorch.make_functional(MLPClassifier().to(DEVICE))
    opt_state = torchopt.adam(lr=0.2).init(weights)
    return weights, opt_state
parallel_init_fn = functorch.vmap(init_fn, randomness='same') # only 'same' works
batched_weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))

instead of using combine_state_for_ensemble:

def init_fn(num_models):
    models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
    _, params, _ = combine_state_for_ensemble(models)
    return params
batched_weights = init_fn(num_models=2)

However, any other randomness setting in functorch.vmap(init_fn) raises an error (e.g. randomness='different'):

Traceback (most recent call last):
  File "parallel_train_torchopt.py", line 196, in <module>
    functorch_original.test_parallel_train_step_fn(num_models=2)
  File "parallel_train_torchopt.py", line 136, in test_parallel_train_step_fn
    weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/functorch/_src/vmap.py", line 365, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "parallel_train_torchopt.py", line 109, in init_fn
    _, weights = make_functional(MLPClassifier().to(DEVICE))
  File "parallel_train_torchopt.py", line 49, in __init__
    self.fc1 = nn.Linear(2, self.hidden_dim)
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 101, in __init__
    self.reset_parameters()
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 107, in reset_parameters
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/nn/init.py", line 412, in kaiming_uniform_
    return tensor.uniform_(-bound, bound)
RuntimeError: vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. If this is necessary for your usage, please file an issue with functorch.

functorch.vmap(init_fn, randomness='same') gives identical initializations for every net in the ensemble, which is not desirable when we want to train ensembles and average across random seeds. Support for randomness='different' in functorch.vmap(init_fn) is therefore a needed feature for this kind of usage.
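For illustration, here is a minimal sketch of one way to get per-model initializations under randomness='different'. It draws the weights with out-of-place torch.rand calls instead of the in-place tensor.uniform_ that appears in the traceback. The helper init_linear, its default sizes, and the import fallback are assumptions made for this example; they are not part of TorchOpt or functorch. The bound 1 / sqrt(fan_in) matches what kaiming_uniform_ with a=sqrt(5) produces for a linear layer's weight.

```python
import math

import torch

try:  # torch >= 2.0 ships vmap under torch.func
    from torch.func import vmap
except ImportError:  # assumption: fall back to the standalone functorch package
    from functorch import vmap


def init_linear(_model_idx, in_features=2, out_features=32):
    """Hypothetical functional init: uniform weights in [-bound, bound]."""
    bound = 1.0 / math.sqrt(in_features)
    # Out-of-place sampling: torch.rand creates a new tensor, so vmap can
    # give every ensemble member its own draw.
    weight = (torch.rand(out_features, in_features) * 2 - 1) * bound
    bias = (torch.rand(out_features) * 2 - 1) * bound
    return weight, bias


num_models = 4
# The mapped argument only sets the ensemble size; its values are unused.
weights, biases = vmap(init_linear, randomness="different")(torch.arange(num_models))
```

Each leading index of weights and biases then holds one ensemble member's parameters. This only sidesteps the in-place call for a hand-written init; it does not make nn.Linear's own reset_parameters work under vmap.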

cc @waterhorse1 @JieRen98 @XuehaiPan

Solution

https://github.com/metaopt/TorchOpt/pull/32 can be run with functorch.vmap(init_fn, randomness='different').

Resource

samdow commented 1 year ago

Thanks for the thorough writeup @Benjamin-eecs! To check a couple of things: is this a current bottleneck for your examples? We had been under the assumption that training would be much more expensive than initialization, but that may not be true (or it may be fair to say that we're losing out on performance by not doing this).

We're also currently looking at different ways that JAX libraries build neural nets and this is a great axis I hadn't thought of before. It looks like you might be using Flax or Haiku and I was wondering if you had tried this with Equinox at all?

cc @zou3519 This seems to be the same thing that the federated learning people were asking for. I forget if we got a clear answer for them.

Benjamin-eecs commented 1 year ago

Hi there @samdow, thanks for your quick and detailed feedback.

is this a current bottleneck for your examples?

I think I can call it a bottleneck in some sense. We can definitely initialize the ensemble of models and optimizers with a for-loop, but our TorchOpt example mainly wants to show that we can support functorch.vmap for both initialization and training of an ensemble of models. Also, in our specific usage where we want to repeat the same training process with different seeds or hyperparameters using functorch.vmap, we think it would be better if users could write init_fn in a functional way, as a function of a list of seeds or a list of hyperparameters. But for now, we can only initialize a set of models with the same weights.
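The for-loop initialization mentioned above could look roughly like the sketch below. It seeds the global generator once per model, builds each model in turn, and stacks the parameters along a new leading ensemble dimension. The names make_mlp and init_ensemble are hypothetical, and the small nn.Sequential stands in for the MLPClassifier from the example, whose definition is not shown in this thread.

```python
import torch
from torch import nn


def make_mlp(hidden_dim=32, n_classes=2):
    # Hypothetical stand-in for MLPClassifier.
    return nn.Sequential(
        nn.Linear(2, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, n_classes),
    )


def init_ensemble(seeds):
    """Build one model per seed and stack parameters by name."""
    per_model = []
    for seed in seeds:
        torch.manual_seed(seed)  # note: this resets the global RNG state
        model = make_mlp()
        per_model.append({name: p.detach() for name, p in model.named_parameters()})
    # Every stacked tensor gets a leading dimension of size len(seeds).
    return {name: torch.stack([p[name] for p in per_model]) for name in per_model[0]}


batched_params = init_ensemble([0, 1, 2])
```

The stacked tensors can then be mapped over dimension 0 during training, while initialization itself stays an ordinary Python loop.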

It looks like you might be using Flax or Haiku and I was wondering if you had tried this with Equinox at all?

I am not sure I fully understand. The JAX code snippet I showed in the writeup is just there to show that our TorchOpt example changes the functorch example into JAX style, with an additional optimizer such as Adam rather than only SGD.

samdow commented 1 year ago

Also, in our specific usage where we want to repeat the same training process with different seeds or hyperparameters using functorch.vmap, we think it would be better if users could write init_fn in a functional way, as a function of a list of seeds or a list of hyperparameters.

To be clear, if this is the end goal, it will probably always be easier to write this as a for loop. Most of the hyperparameters are scalar values and right now we can't vmap over lists or tensors of scalar values (1D tensors that we vmap over are going to be treated as scalar tensors instead of scalars).

As an example, if we had an ensemble of models like the ones in the TorchOpt PR but where the hidden dimension was being changed:

class MLP(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        ...
        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
    ...

we would never be able to vmap over different values for hidden_dim, because each value changes the shapes of the parameters.
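A small sketch of why this is a shape problem rather than a randomness problem: vmap operates on tensors that share one shape apart from the mapped dimension, and layers built with different hidden_dim values have weights that cannot even be stacked. The two hidden sizes below are arbitrary values chosen for illustration.

```python
import torch
from torch import nn

# Two first layers that differ only in hidden_dim (illustrative values).
fc1_small = nn.Linear(2, 16)
fc1_large = nn.Linear(2, 32)

shape_small = tuple(fc1_small.weight.shape)  # (16, 2)
shape_large = tuple(fc1_large.weight.shape)  # (32, 2)

try:
    torch.stack([fc1_small.weight, fc1_large.weight])
    stackable = True
except RuntimeError:
    # torch.stack requires all tensors to have the same size.
    stackable = False
```

With no common shape there is no batched parameter tensor to map over, so an ensemble that varies hidden_dim has to be built and run model by model.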

the Jax code snippet I showed in the writeup just to present that our TorchOpt example

I see! Thanks for that clarification. I saw classifier_fns.init in the code snippet and assumed 😄