Open dellis23 opened 1 year ago
That's private, no? Is it intended to be used outside the pytorch repo itself?
Also, as far as I can tell, it's not the same as the jax tree_map, which allows for multiple inputs, letting you do things like aggregations across the inputs.
Functorch, torch fx and torchvision are using this submodule. @zou3519 can give more visibility on when this module will be public.
Also, as far as I can tell, it's not the same as the jax tree_map, which allows for multiple inputs, letting you do things like aggregations across the inputs.
Maybe one can use the pytree flatten and unflatten methods to get JAX-like behavior.
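A minimal sketch of that suggestion, assuming `tree_flatten` and `tree_unflatten` from the private `torch.utils._pytree` module. The helper name `tree_map_multi` is hypothetical and not part of PyTorch; it only illustrates how flattening every input against a shared structure gives a multi-input map.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def tree_map_multi(fn, tree, *rest):
    """Hypothetical helper: apply fn leaf-wise across pytrees of identical structure."""
    leaves, spec = tree_flatten(tree)
    rest_leaves = []
    for other in rest:
        other_leaves, other_spec = tree_flatten(other)
        # Refuse to zip leaves together if the container structures differ.
        if other_spec != spec:
            raise ValueError("all pytrees must share the same structure")
        rest_leaves.append(other_leaves)
    # Call fn once per leaf position, passing one leaf from each tree.
    mapped = [fn(*args) for args in zip(leaves, *rest_leaves)]
    return tree_unflatten(mapped, spec)


params = {"w": torch.ones(2, 2), "b": torch.zeros(2)}
grads = {"w": torch.full((2, 2), 0.5), "b": torch.ones(2)}
lr = 0.1

# SGD-style update combining two trees leaf by leaf.
new_params = tree_map_multi(lambda p, g: p - lr * g, params, grads)
```

The result has the same dict structure as `params`, with each leaf updated from the matching leaf of `grads`.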
@dellis23, eventually we want to make some form of tree_map available in PyTorch. This could be re-using JAX's version or using the implementation in torch.utils._pytree.
Also, as far as I can tell, it's not the same as the jax tree_map, which allows for multiple inputs, letting you do things like aggregations across the inputs.
You're right that we don't support everything there yet.
If the goal right now is to test the models, please feel free to use torch.utils._pytree (it works and we expect to maintain BC for its API). See https://github.com/pytorch/pytorch/issues/65761 for the larger tracking issue.
@zou3519 Do you think it might be possible to use tree_map to update parameters? Instead of doing,
optim.zero_grad()
loss.backward()
optim.step()
Could you do something like this?
grad, loss = grad_and_value(loss_fn, argnums=0)(params, x) #returns scalar loss
params = torch.utils._pytree.tree_map(lambda p, g: p - lr * g, params, grad) #update params with SGD
The only reason I ask is if it might be quicker to precondition gradients via a tree map than defining everything within optim.step?
We could add something like that, yes.
if it might be quicker to precondition gradients via a tree map than defining everything within optim.step?
Why would this be faster?
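The functional update being asked about can be sketched as follows. This is an illustration under assumptions, not an existing PyTorch optimizer API: it uses `torch.func.grad_and_value` (the successor to the functorch function of the same name), and the `loss_fn`, `sgd_step`, and toy regression data are hypothetical.

```python
import torch
from torch.func import grad_and_value
from torch.utils._pytree import tree_flatten, tree_unflatten


def loss_fn(params, x, y):
    # Hypothetical linear model with a mean-squared-error loss.
    pred = x @ params["w"] + params["b"]
    return torch.mean((pred - y) ** 2)


def sgd_step(params, x, y, lr=0.1):
    # grad_and_value returns (gradients, loss); gradients mirror the params pytree.
    grads, loss = grad_and_value(loss_fn)(params, x, y)
    p_leaves, spec = tree_flatten(params)
    g_leaves, _ = tree_flatten(grads)
    new_leaves = [p - lr * g for p, g in zip(p_leaves, g_leaves)]
    return tree_unflatten(new_leaves, spec), loss


torch.manual_seed(0)
x = torch.randn(32, 3)
y = x @ torch.tensor([1.0, -2.0, 0.5]) + 0.3
params = {"w": torch.zeros(3), "b": torch.zeros(())}

losses = []
for _ in range(50):
    params, loss = sgd_step(params, x, y)
    losses.append(loss.item())
```

No `.grad` attributes or `optim.step()` are involved here; the parameters are replaced by new tensors on every step.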
@zou3519, Let's say I have some network that represents an R^N -> R^1 function, and I want to precondition the gradient of the loss by its Fisher information matrix (FIM).
The FIM is the expectation of the outer product of the Jacobian of the network, so something like,
fim = torch.einsum("bi,bj->ij", jacobian, jacobian)/batch
The loss is computed by taking the mean and then backpropagating to get the gradient, but the FIM requires all the samples, so in order to precondition the gradients I would need to recompute per-sample gradients. Therefore, the current method would be something like,
y = loss_fn(x) #1st forward pass
loss = torch.mean(y)
loss.backward() #compute grad (1st backward pass)
per_sample_jacobian = vmap(grad(loss_fn, argnums=0), in_dims=(None, 0))(params, x) #2nd forward/backward pass
fim = torch.einsum("bi,bj->ij", per_sample_jacobian, per_sample_jacobian)/batch
# then precondition via torch.linalg.solve(), and update parameters
Whereas a tree map approach might look like,
per_sample_jacobian, losses = vmap(grad_and_value(loss_fn, argnums=0), in_dims=(None, 0))(params, x) #1st forward/backward pass
loss = torch.mean(losses, dim=0)
grad = torch.mean(per_sample_jacobian, dim=0)
fim = torch.einsum("bi,bj->ij", per_sample_jacobian, per_sample_jacobian)/batch
# then precondition via torch.linalg.solve(), and update parameters
One initial problem that comes to mind is that because loss.backward() isn't called, the .grad attribute isn't populated for optim.step(), and so I still have to call .backward() regardless of this speed-up, which kinda defeats the point of it! This also assumes all samples in your batch are independent, which may be an issue for some use cases.
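One possible way around the `.grad` problem is to write the preconditioned gradient into `.grad` by hand before calling `optim.step()`. The sketch below assumes `torch.func` (`functional_call`, `grad_and_value`, `vmap`), a toy `nn.Linear` model, and a squared-error per-sample loss. It treats per-sample loss gradients as the "Jacobian", as in the snippets above, and adds a small `damping` term (an assumption, not from the thread) so the solve stays well conditioned.

```python
import torch
from torch import nn
from torch.func import functional_call, grad_and_value, vmap

torch.manual_seed(0)
net = nn.Linear(3, 1)
optim = torch.optim.SGD(net.parameters(), lr=0.1)
params = {k: v.detach() for k, v in net.named_parameters()}
names = list(params.keys())  # fixes the parameter order used for flattening


def sample_loss(params, x, y):
    # Loss for ONE sample; vmap supplies the batch dimension.
    pred = functional_call(net, params, (x,)).squeeze(0)
    return (pred - y) ** 2


batch = 16
xs, ys = torch.randn(batch, 3), torch.randn(batch)

# Single forward/backward pass yielding per-sample gradients and losses.
per_sample_grads, losses = vmap(
    grad_and_value(sample_loss), in_dims=(None, 0, 0)
)(params, xs, ys)

# Flatten the gradient pytree into a (batch, n_params) matrix.
jac = torch.cat([per_sample_grads[k].reshape(batch, -1) for k in names], dim=1)
loss = losses.mean()
grad = jac.mean(dim=0)
fim = torch.einsum("bi,bj->ij", jac, jac) / batch

damping = 1e-3  # assumed regulariser
nat_grad = torch.linalg.solve(fim + damping * torch.eye(fim.shape[0]), grad)

# Populate .grad manually so a standard optimizer can consume the result.
offset = 0
for p in net.parameters():
    n = p.numel()
    p.grad = nat_grad[offset:offset + n].view_as(p).clone()
    offset += n
optim.step()
```

This relies on `net.parameters()` and `net.named_parameters()` iterating in the same order, and on the samples being independent, as noted above.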
Hi @AlphaBetaGamma96, please check out TorchOpt and OpTree. TorchOpt provides a user-friendly API for your use case:
class Net(nn.Module): ...
class Loader(DataLoader): ...
net = Net() # init
loader = Loader()
optimizer = torchopt.adam()
model, params = functorch.make_functional(net) # use functorch extract network parameters
opt_state = optimizer.init(params) # init optimizer
xs, ys = next(loader) # get data
pred = model(params, xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
grads = torch.autograd.grad(loss, params) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = torchopt.apply_updates(params, updates) # update network parameters
I'm working on testing some models using functorch along with torch-mlir and IREE. I don't see an analog of jax's tree_map. Is this something it makes sense for functorch to implement, or is there a recommended alternative?
Hi @dellis23, if you are interested, please check out OpTree for an analog of JAX pytrees and a more performant version of torch.utils._pytree.