pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Error about using a grad transform with in-place operation is inconsistent with and without DDP #1112

Open XuchanBao opened 1 year ago

XuchanBao commented 1 year ago

Hi,

I was using torch.func in PyTorch 2.0 to compute the Hessian-vector product (hvp) of a neural network.

I first used torch.func.functional_call to define a functional version of the neural network model, and then proceeded to use torch.func.jvp and torch.func.grad to compute the hvp.
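Roughly, my setup looks like the following (a simplified sketch: a toy linear model and MSE loss stand in for my actual network and objective):

```python
import torch
from torch.func import functional_call, grad, jvp

# Toy stand-ins for the real model and data
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

params = dict(model.named_parameters())

def loss_fn(p):
    # Functional version of the model: parameters are passed in explicitly
    out = functional_call(model, p, (x,))
    return torch.nn.functional.mse_loss(out, y)

# Direction for the Hessian-vector product, with the same structure as params
v = {k: torch.randn_like(t) for k, t in params.items()}

# Forward-over-reverse hvp: jvp of grad returns (gradient, Hessian-vector product)
_, hv = jvp(grad(loss_fn), (params,), (v,))
```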

The above works when I use a single GPU without parallel processing. However, when I wrapped the model with DistributedDataParallel (DDP), I got the following error:

*** RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::copy_) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.

I am confused by this error: if there really were such an in-place operation (I couldn't find one in my model.forward() code), I'd expect the error to occur regardless of whether DDP is used. Given the inconsistent behaviour, can I still trust the hvp result computed without DDP?

My torch version is 2.0.0.dev20230119+cu117

zou3519 commented 1 year ago

@XuchanBao do you have a script that reproduces the problem that we could take a look at?

DistributedDataParallel does some extra things to the model, so it's likely that your hvp result without DDP is correct, and that those extra things are what's interacting badly with vmap.
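One way to check (untested sketch; `ddp_model`, `x`, and `y` are placeholders for your actual setup) would be to run the same transforms against the underlying module rather than the DDP wrapper and see whether the error goes away:

```python
import torch
from torch.func import functional_call, grad, jvp

# Unwrap the DDP model to get the plain nn.Module underneath
base_model = ddp_model.module
params = dict(base_model.named_parameters())

def loss_fn(p):
    out = functional_call(base_model, p, (x,))
    return torch.nn.functional.mse_loss(out, y)

v = {k: torch.randn_like(t) for k, t in params.items()}

# Same forward-over-reverse hvp pattern as in the single-GPU case
_, hv = jvp(grad(loss_fn), (params,), (v,))
```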