XuchanBao opened 1 year ago
@XuchanBao do you have a script that reproduces the problem that we could take a look at?
DistributedDataParallel does some extra things to the model, so it's likely that your hvp result is correct but the DDP extra things are interacting badly with vmap.
Hi,

I was using `torch.func` in PyTorch 2.0 to compute the Hessian-vector product (hvp) of a neural network. I first used `torch.func.functional_call` to define a functional version of the model, and then used `torch.func.jvp` and `torch.func.grad` to compute the hvp. This works when I use a single GPU without parallel processing. However, when I wrapped the model with DistributedDataParallel (DDP), it gave the following error:
*** RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::copy_) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.
I am confused by this error: if there really were such in-place operations (I couldn't find any in my `model.forward()` code), I'd expect the error to occur regardless of DDP. Given this inconsistent behaviour, can I still trust the hvp result from the non-DDP run?
My torch version is `2.0.0.dev20230119+cu117`.