pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Add support for `tree_map` or document recommended alternative #1030

Open dellis23 opened 1 year ago

dellis23 commented 1 year ago

I'm working on testing some models using functorch along with torch-mlir and IREE. I don't see an analog of jax's tree_map. Is this something it makes sense for functorch to implement, or is there a recommended alternative?

vfdev-5 commented 1 year ago

Please check: https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py#L190

dellis23 commented 1 year ago

That's private, no? Is it intended to be used outside the pytorch repo itself?

Also, as far as I can tell, it's not the same as the jax tree_map, which allows for multiple inputs, letting you do things like aggregations across the inputs.
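
For reference, JAX's multi-input behavior looks roughly like this (a minimal sketch using jax.tree_util; the values are illustrative):

import jax.tree_util as jtu

params = {"w": 1.0, "b": 2.0}
grads = {"w": 0.1, "b": 0.2}

# leaves from both trees are passed to the function together
updated = jtu.tree_map(lambda p, g: p - 0.5 * g, params, grads)
# -> w: 0.95, b: 1.9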

vfdev-5 commented 1 year ago

functorch, torch.fx, and torchvision all use this submodule. @zou3519 can give more visibility on when this module will become public.

Also, as far as I can tell, it's not the same as the jax tree_map, which allows for multiple inputs, letting you do things like aggregations across the inputs.

Maybe one can use the pytree tree_flatten and tree_unflatten methods to get JAX-like behavior.
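
For example, a minimal multi-input tree_map could be sketched on top of the flatten/unflatten helpers (tree_map_multi is a hypothetical name; it assumes all trees share the same structure):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def tree_map_multi(fn, tree, *rests):
    # Hypothetical JAX-style multi-input tree_map built from flatten/unflatten.
    # Assumes every tree in `rests` has the same structure as `tree`.
    leaves, spec = tree_flatten(tree)
    rest_leaves = [tree_flatten(r)[0] for r in rests]
    new_leaves = [fn(*xs) for xs in zip(leaves, *rest_leaves)]
    return tree_unflatten(new_leaves, spec)

params = {"w": torch.ones(3), "b": torch.zeros(3)}
grads = {"w": torch.full((3,), 0.1), "b": torch.full((3,), 0.2)}
new_params = tree_map_multi(lambda p, g: p - 0.01 * g, params, grads)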

zou3519 commented 1 year ago

@dellis23, eventually we want to make some form of tree_map available in PyTorch. This could be re-using JAX's version or using the implementation in torch.utils._pytree.

Also, as far as I can tell, it's not the same as the jax tree_map, which allows for multiple inputs, letting you do things like aggregations across the inputs.

You're right that we don't support everything there yet.

If the goal right now is to test the models, please feel free to use torch.utils._pytree (it works and we expect to maintain BC for its API). See https://github.com/pytorch/pytorch/issues/65761 for the larger tracking issue.
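
For what it's worth, the single-input tree_map there can be used along these lines (a minimal sketch; since the module is private, the exact API may change):

import torch
from torch.utils._pytree import tree_map

params = {"layer1": {"w": torch.randn(3, 3), "b": torch.zeros(3)},
          "layer2": {"w": torch.randn(3, 1)}}

# apply a function to every tensor leaf in the nested structure
detached = tree_map(lambda t: t.detach().clone(), params)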

AlphaBetaGamma96 commented 1 year ago

@zou3519 Do you think it might be possible to use tree_map to update parameters? Instead of doing,

optim.zero_grad()
loss.backward()
optim.step()

Could you do something like this?

grad, loss = grad_and_value(loss_fn, argnums=0)(params, x)  # returns gradients and the scalar loss
params = torch.utils._pytree.tree_map(lambda p, g: p - lr * g, params, grad)  # update params with SGD (would need a multi-input tree_map)

The only reason I ask is that it might be quicker to precondition gradients via a tree map than to define everything within optim.step().
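
A rough end-to-end sketch of that idea, using flatten/unflatten to stand in for a multi-input tree_map (loss_fn, lr, and the shapes are illustrative assumptions):

import torch
from functorch import grad_and_value
from torch.utils._pytree import tree_flatten, tree_unflatten

lr = 0.1
params = {"w": torch.randn(4, 1), "b": torch.zeros(1)}
x = torch.randn(8, 4)

def loss_fn(params, x):
    # hypothetical scalar loss for illustration
    return ((x @ params["w"] + params["b"]) ** 2).mean()

grads, loss = grad_and_value(loss_fn, argnums=0)(params, x)

# SGD-style update over matching leaves, emulating tree_map(lambda p, g: p - lr * g, params, grads)
p_leaves, spec = tree_flatten(params)
g_leaves, _ = tree_flatten(grads)
params = tree_unflatten([p - lr * g for p, g in zip(p_leaves, g_leaves)], spec)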

zou3519 commented 1 year ago

We could add something like that, yes.

if it might be quicker to precondition gradients via a tree map than defining everything within optim.step?

Why would this be faster?

AlphaBetaGamma96 commented 1 year ago

@zou3519, Let's say I have some network that represents an R^N -> R^1 function, and I want to precondition the gradient of the loss by its Fisher information matrix (FIM).

The FIM is the expectation of the outer product of the per-sample Jacobian of the network, so it's something like fim = torch.einsum("bi,bj->ij", jacobian, jacobian) / batch.

The loss is computed by taking the mean and then backpropagating to get the gradient, but the FIM requires all the samples, so in order to precondition the gradient I would need to recompute the per-sample gradients. Therefore, the current method would be something like:

y = loss_fn(params, x)  # 1st forward pass
loss = torch.mean(y)
loss.backward()  # compute grad (1st backward pass)

per_sample_jacobian = vmap(grad(loss_fn, argnums=0), in_dims=(None, 0))(params, x)  # 2nd forward/backward pass

fim = torch.einsum("bi,bj->ij", per_sample_jacobian, per_sample_jacobian) / batch
# then precondition via torch.linalg.solve(), and update parameters

Whereas a tree map approach might look like:

per_sample_jacobian, losses = vmap(grad_and_value(loss_fn, argnums=0), in_dims=(None, 0))(params, x)  # 1st forward/backward pass

loss = torch.mean(losses, dim=0)
grad = torch.mean(per_sample_jacobian, dim=0)

fim = torch.einsum("bi,bj->ij", per_sample_jacobian, per_sample_jacobian) / batch
# then precondition via torch.linalg.solve(), and update parameters

One initial problem that comes to mind is that because loss.backward() isn't called, the .grad attribute isn't populated for optim.step(), so I still have to call .backward() regardless of this speed-up, which kind of defeats the point! This also assumes all samples in your batch are independent, which may be an issue for some use cases.
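
Just to make the comparison concrete, here is a minimal runnable sketch of the single-pass version (the model, loss_fn, and the damping term added before the solve are illustrative assumptions, not a recommendation):

import torch
from functorch import make_functional, vmap, grad_and_value

net = torch.nn.Linear(4, 1)              # tiny R^N -> R^1 model for illustration
fmodel, params = make_functional(net)

def loss_fn(params, x):
    # per-sample scalar loss (illustrative)
    return fmodel(params, x.unsqueeze(0)).squeeze() ** 2

x = torch.randn(32, 4)
batch = x.shape[0]

# per-sample gradients and losses in a single forward/backward pass
per_sample_grads, losses = vmap(grad_and_value(loss_fn, argnums=0), in_dims=(None, 0))(params, x)

loss = losses.mean()
# flatten each per-sample gradient into one (batch, n_params) Jacobian matrix
jacobian = torch.cat([g.reshape(batch, -1) for g in per_sample_grads], dim=1)

fim = torch.einsum("bi,bj->ij", jacobian, jacobian) / batch
grad_flat = jacobian.mean(dim=0)
# precondition; a small damping term keeps the solve well-conditioned
natural_grad = torch.linalg.solve(fim + 1e-3 * torch.eye(fim.shape[0]), grad_flat)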

Benjamin-eecs commented 1 year ago

@zou3519 Do you think it might be possible to use tree_map to update parameters? Instead of doing,

optim.zero_grad()
loss.backward()
optim.step()

Could you do something like this?

grad, loss = grad_and_value(loss_fn, argnums=0)(params, x)  # returns gradients and the scalar loss
params = torch.utils._pytree.tree_map(lambda p, g: p - lr * g, params, grad)  # update params with SGD (would need a multi-input tree_map)

The only reason I ask is that it might be quicker to precondition gradients via a tree map than to define everything within optim.step().

Hi @AlphaBetaGamma96, please check out TorchOpt and OpTree. TorchOpt provides a user-friendly API for your use case:


import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt
from torch.utils.data import DataLoader

class Net(nn.Module): ...

class Loader(DataLoader): ...

net = Net()  # init
loader = Loader()
optimizer = torchopt.adam()

model, params = functorch.make_functional(net)           # use functorch to extract network parameters
opt_state = optimizer.init(params)                       # init optimizer

xs, ys = next(iter(loader))                              # get data
pred = model(params, xs)                                 # forward
loss = F.cross_entropy(pred, ys)                         # compute loss

grads = torch.autograd.grad(loss, params)                # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # get updates
params = torchopt.apply_updates(params, updates)         # update network parameters

Benjamin-eecs commented 1 year ago

I'm working on testing some models using functorch along with torch-mlir and IREE. I don't see an analog of jax's tree_map. Is this something it makes sense for functorch to implement, or is there a recommended alternative?

Hi @dellis23, if you are interested, please check out OpTree for an analog of JAX's pytree utilities and a more performant version of torch.utils._pytree.
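
For example, OpTree's tree_map accepts multiple input trees, similar to jax.tree_util.tree_map (a minimal sketch; the values are illustrative):

import torch
import optree

params = {"w": torch.ones(3), "b": torch.zeros(3)}
grads = {"w": torch.full((3,), 0.1), "b": torch.full((3,), 0.2)}

# leaves from both trees are passed to the function together
new_params = optree.tree_map(lambda p, g: p - 0.01 * g, params, grads)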