pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Feature Request: make_functional for torch.optim.Optimizer #372

Open teddykoker opened 2 years ago

teddykoker commented 2 years ago

Thank you for the great work on this library!

One common pattern I am noticing is using the gradients from the grad() function to perform the optimization step:

grads = grad(loss_fn)(params, ...)
params = [p - g * learning_rate for p, g in zip(params, grads)]
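For illustration, here is a dependency-free sketch of that pattern. It assumes plain Python floats in place of tensors and a hand-derived gradient function in place of `grad(loss_fn)`; the names `loss_fn`, `grad_loss_fn`, and `sgd_step` are hypothetical.

```python
def loss_fn(params, target):
    # Sum of squared errors between each parameter and its target value.
    return sum((p - t) ** 2 for p, t in zip(params, target))


def grad_loss_fn(params, target):
    # Hand-derived gradient of loss_fn: d/dp (p - t)^2 = 2 * (p - t).
    # In the snippet above, grad(loss_fn) plays this role.
    return [2.0 * (p - t) for p, t in zip(params, target)]


def sgd_step(params, grads, learning_rate):
    # Pure function: returns new parameters and leaves its inputs untouched.
    return [p - g * learning_rate for p, g in zip(params, grads)]


params = [0.0, 4.0]
target = [1.0, 2.0]
for _ in range(100):
    grads = grad_loss_fn(params, target)
    params = sgd_step(params, grads, learning_rate=0.1)

print([round(p, 4) for p in params])  # prints [1.0, 2.0]
```

Plain SGD needs no state beyond the parameters, which is why this pattern is easy; optimizers with running buffers are where it breaks down.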

While this is relatively straightforward to do if vanilla mini-batch gradient descent is desired, there seems to be no way to use other optimization methods without:

  1. Manually setting the .grad for each parameter and then using the class-based optimizer in torch.optim
  2. Using the functional optimizer interface in torch.optim._functional while manually initializing and passing in the necessary elements of the state
  3. Implementing the optimizer yourself
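As a rough sketch of the third option, here is SGD with momentum written as a pair of pure functions, with the velocity buffers passed around explicitly. Plain Python floats stand in for tensors, and `momentum_init` / `momentum_update` are hypothetical names. The update follows the heavy-ball form v ← μ·v + g, p ← p − lr·v, which matches torch.optim.SGD with dampening=0 and nesterov=False.

```python
def momentum_init(params):
    # One velocity buffer per parameter, starting at zero.
    return [0.0 for _ in params]


def momentum_update(params, grads, velocity, lr=0.1, momentum=0.9):
    # Heavy-ball momentum: v <- momentum * v + g, then p <- p - lr * v.
    # Returns new lists instead of mutating the inputs.
    new_velocity = [momentum * v + g for v, g in zip(velocity, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_velocity)]
    return new_params, new_velocity


# Minimize f(p) = p ** 2, whose gradient is 2 * p.
params = [5.0]
velocity = momentum_init(params)
for _ in range(300):
    grads = [2.0 * p for p in params]
    params, velocity = momentum_update(params, grads, velocity)
```

Even for this simple optimizer, the caller has to thread the velocity through every step by hand, which is the bookkeeping the request would like to avoid.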

One possible solution to this problem would be to extend make_functional() or create make_functional_optimizer() to support torch.optim.Optimizer. A potential API could look something like:

optimizer = torch.optim.Adam(params, lr=3e-4)

# state contains the optimizer state
# update_fn is a stateless function that returns new params and a new state, given gradients and the current state
opt_state, update_fn = make_functional(optimizer)

grads = grad(loss_fn)(params, ...)

# update params and state using update_fn
params, opt_state = update_fn(grads, opt_state)
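A minimal sketch of how the proposed shape could behave for Adam, written in plain Python so it runs without PyTorch. `make_functional_adam` is a hypothetical helper, not a functorch or torch.optim API. Because the proposed update_fn(grads, opt_state) is not handed the parameters, the sketch assumes the state carries the current parameters alongside Adam's two running averages. The update rule is the standard bias-corrected Adam step.

```python
import math


def make_functional_adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Hypothetical helper returning (opt_state, update_fn) for Adam."""
    beta1, beta2 = betas
    opt_state = {
        "params": list(params),             # assumption: state holds params
        "step": 0,
        "exp_avg": [0.0] * len(params),     # running mean of gradients
        "exp_avg_sq": [0.0] * len(params),  # running mean of squared gradients
    }

    def update_fn(grads, state):
        step = state["step"] + 1
        exp_avg = [beta1 * m + (1 - beta1) * g
                   for m, g in zip(state["exp_avg"], grads)]
        exp_avg_sq = [beta2 * v + (1 - beta2) * g * g
                      for v, g in zip(state["exp_avg_sq"], grads)]
        # Bias corrections for the zero-initialized running averages.
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        new_params = [
            p - lr * (m / bias_correction1) / (math.sqrt(v / bias_correction2) + eps)
            for p, m, v in zip(state["params"], exp_avg, exp_avg_sq)
        ]
        new_state = {
            "params": new_params,
            "step": step,
            "exp_avg": exp_avg,
            "exp_avg_sq": exp_avg_sq,
        }
        # Nothing in `state` is mutated; the caller gets fresh objects back.
        return new_params, new_state

    return opt_state, update_fn


# Usage, mirroring the proposal: minimize sum(p ** 2).
params = [3.0, -2.0]
opt_state, update_fn = make_functional_adam(params, lr=0.05)
initial_state = opt_state
for _ in range(2000):
    grads = [2.0 * p for p in params]
    params, opt_state = update_fn(grads, opt_state)
```

Since `initial_state` is never modified, an earlier state can be reused (for example to retry a step), at the cost of keeping both copies alive.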

I believe the proposed API would be possible using a method similar to the one make_functional already uses for nn.Module. Obviously there are a number of ways the API could work (e.g. JAX and Optax have slightly differently structured functional optimizer APIs), but I thought it would be good to gauge interest and see whether such a thing would be worth implementing!
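Optax, for example, represents an optimizer as an `init`/`update` pair where `update` returns parameter *updates* rather than new parameters, and a separate `optax.apply_updates` adds them to the parameters. The sketch below mimics that shape in plain Python. `GradientTransformation`, `sgd`, and `apply_updates` here are local stand-ins, not imports from Optax, and Optax's real `update` also accepts an optional `params` argument that is omitted here.

```python
from typing import Callable, List, NamedTuple, Tuple


class GradientTransformation(NamedTuple):
    # Plain-Python stand-in shaped like Optax's (init, update) pair; not Optax itself.
    init: Callable[[List[float]], tuple]
    update: Callable[[List[float], tuple], Tuple[List[float], tuple]]


def sgd(lr: float) -> GradientTransformation:
    def init(params: List[float]) -> tuple:
        return ()  # plain SGD keeps no state

    def update(grads: List[float], state: tuple) -> Tuple[List[float], tuple]:
        # Emit *updates* (to be added to params), not new parameters.
        return [-lr * g for g in grads], state

    return GradientTransformation(init, update)


def apply_updates(params: List[float], updates: List[float]) -> List[float]:
    # Parameters change only here, separately from the optimizer logic.
    return [p + u for p, u in zip(params, updates)]


params = [1.0, -2.0]
optimizer = sgd(lr=0.1)
opt_state = optimizer.init(params)
grads = [0.5, -1.0]
updates, opt_state = optimizer.update(grads, opt_state)
params = apply_updates(params, updates)
```

Compared with the proposal, where update_fn returns new params directly, this split keeps the optimizer unaware of the parameters, which makes transformations easier to chain.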

zou3519 commented 2 years ago

Thanks for the issue, @teddykoker. This seems like a reasonable API. We've had some decision paralysis in the past on what the best API for this would look like, but having something working is always better than having something that doesn't work :).

One question: if I understand the proposal correctly, update_fn(grads, opt_state) does not modify opt_state in place but instead returns a new opt_state. Do you think it will be a problem that this doesn't do the in-place modification? I don't have a sense of how large optimizer states are in general.

teddykoker commented 2 years ago

Thanks for the reply, @zou3519. You are correct that the above proposal does not modify the optimizer state in place. My reasoning was to conform to a more "functional" style and avoid side effects; however, this certainly doesn't have to be the case.

Regarding the size of optimizer states, I believe they are usually on the same order of magnitude as the model. For example, Adam maintains a running mean of the gradients and a running mean of the squared gradients, so its state is roughly twice the size of the model's parameters. If I understand correctly, PyTorch would temporarily need enough memory to hold both the old and the new state, since the new state is allocated before the old one stops being referenced, which could cause memory issues.
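To make the trade-off concrete, here is a minimal sketch contrasting the two styles for SGD with momentum. Plain Python lists stand in for tensors, and both function names are hypothetical. The functional version builds a new velocity list while the old one is still referenced, so both copies exist at once; the in-place version overwrites the existing buffer.

```python
def momentum_update_functional(params, grads, velocity, lr=0.1, momentum=0.9):
    # Allocates new lists: while this runs, the old and new velocity buffers
    # both exist, roughly doubling the peak memory used by the state.
    new_velocity = [momentum * v + g for v, g in zip(velocity, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_velocity)]
    return new_params, new_velocity


def momentum_update_inplace(params, grads, velocity, lr=0.1, momentum=0.9):
    # Overwrites the existing buffers element by element, so no second copy
    # of the state is allocated. Mutates params and velocity; returns None.
    for i, g in enumerate(grads):
        velocity[i] = momentum * velocity[i] + g
        params[i] -= lr * velocity[i]


grads = [0.5, -0.5]

params_f, velocity_f = [1.0, 2.0], [0.0, 0.0]
for _ in range(3):
    params_f, velocity_f = momentum_update_functional(params_f, grads, velocity_f)

params_i, velocity_i = [1.0, 2.0], [0.0, 0.0]
velocity_buffer = velocity_i  # keep a handle on the original list
for _ in range(3):
    momentum_update_inplace(params_i, grads, velocity_i)
```

Both produce the same numbers; the difference is only whether the caller ever holds two versions of the state.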

waterhorse1 commented 2 years ago

Hi @teddykoker, we have open-sourced TorchOpt, which can be combined with functorch to perform functional optimization. It may provide the feature you are looking for.

waterhorse1 commented 2 years ago

Recently, we also incorporated vmap, one of the major features of functorch, into TorchOpt, which enables batchable optimization. We have a pull request here and provide a Colab notebook to play with it.