Saving and loading the model on each iteration works around the issue, and it is in fact faster than not doing it:
import io
import torch

for i in range(0, 10):
    print("Running iteration {}".format(i))
    # Round-trip the scripted model through an in-memory buffer each iteration
    buffer = io.BytesIO()
    torch.jit.save(model, buffer)
    buffer.seek(0)
    model = torch.jit.load(buffer).to(device="cuda")
    y, neg_dy = model(z, pos, batch)
This is the offending line; setting tensor_m = tensor instead, or using 0 layers, makes the problem go away:
https://github.com/torchmd/torchmd-net/blob/fd8395463370c6b0510d4fe24b25d937382247a1/torchmdnet/models/tensornet.py#L331
Replacing it with this does NOT solve the issue:
tensor_m = torch.zeros_like(tensor).scatter_add(0, edge_index[0][:, None, None, None].expand_as(msg), msg)
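For what it's worth, the two formulations should compute the same sum-reduction. Here is a minimal standalone check; the shapes and the torch_geometric.utils.scatter import are my assumptions, not the exact code in tensornet.py:

import torch
from torch_geometric.utils import scatter  # assumed source of scatter

num_nodes, num_edges = 10, 40
msg = torch.randn(num_edges, 2, 3, 3)                  # per-edge messages
edge_index0 = torch.randint(0, num_nodes, (num_edges,))
tensor = torch.zeros(num_nodes, 2, 3, 3)

# torch_geometric-style reduction, as on the offending line
a = scatter(msg, edge_index0, dim=0, dim_size=num_nodes, reduce="sum")
# plain torch.Tensor.scatter_add formulation from above
b = torch.zeros_like(tensor).scatter_add(
    0, edge_index0[:, None, None, None].expand_as(msg), msg
)
assert torch.allclose(a, b)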
But my suspicion is that scatter_add is the same thing scatter uses under the hood anyway...

EDIT: Not saving the graph when calling grad also makes the error go away:
if self.derivative:
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
    # With create_graph/retain_graph enabled the error appears;
    # setting both to False makes it go away.
    dy = grad(
        [y],
        [pos],
        grad_outputs=grad_outputs,
        create_graph=False,
        retain_graph=False,
    )[0]
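For context, the model presumably creates and retains the graph here so that, during training, a loss on the negative derivative (the forces) can be backpropagated through this first grad call. A minimal sketch of that double-backward pattern, using illustrative names rather than the torchmd-net API:

import torch
from torch.autograd import grad

pos = torch.randn(5, 3, requires_grad=True)
y = (pos ** 4).sum()  # stand-in for the predicted energy

# First derivative; create_graph=True records this computation so
# gradients can later flow through the forces themselves.
(dy,) = grad([y], [pos], grad_outputs=[torch.ones_like(y)], create_graph=True)
forces = -dy

# A force-matching loss can now be backpropagated (second backward pass).
loss = (forces ** 2).sum()
loss.backward()
print(pos.grad.shape)  # torch.Size([5, 3])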
The test fails once it reaches its fifth iteration, and 5 is a magic number we have encountered before when dealing with CUDA torch models; see for instance https://github.com/openmm/openmm-torch/pull/122
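If the magic number comes from the JIT's profiling executor warming up for a fixed number of runs before optimizing, the failing iteration should move when that count is changed. A sketch of the knob I have in mind; this is an internal, version-specific API from the PyTorch 1.13/2.0 era, so treat the name as an assumption:

import torch

# Number of profiling runs the JIT performs before specializing/fusing a
# scripted graph; if fusion is the culprit, changing this should shift the
# iteration at which the failure appears.
torch._C._jit_set_num_profiled_runs(10)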
I only made the connection with this issue just now, while reading about torch.jit's fusion mechanism with NVFuser: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/cuda/README.md#general-ideas-of-debug-no-fusion
What I believe is going on here is some kind of collision between jit.script, the NVFuser backend, and autograd.
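One way to test that hypothesis is to turn NVFuser off and rerun the loop. Both knobs below are from the PyTorch 1.13/2.0 era and may have changed since:

import torch

# Disable the NVFuser backend globally before running the scripted model
torch._C._jit_set_nvfuser_enabled(False)

# Or select the fuser explicitly for a region of code
# ("fuser0" = legacy, "fuser1" = NNC, "fuser2" = NVFuser):
with torch.jit.fuser("fuser0"):
    y, neg_dy = model(z, pos, batch)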
This issue goes away with pytorch > 2.0.0, so I am going to assume it was a bug there that has since been fixed. I will leave this open until there is a conda-forge package for pytorch 2.1.
Closing this, as there is a conda-forge package for pytorch 2.1 already.
This test will run exactly 4 iterations and then print an error:
The error: