taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0

Wrong gradient when using Taichi autodiff (kernel.grad) and PyTorch torch.autograd.Function together #8534

Status: Open. zjcs opened this issue 1 month ago.

zjcs commented 1 month ago

Describe the bug

Gradients are wrong when a Taichi kernel's autodiff (`kernel.grad`) is called inside a PyTorch `torch.autograd.Function`.

To Reproduce

```python
import taichi as ti
import torch

ti.init(arch=ti.cpu)

@ti.kernel
def func_x2(x: ti.types.ndarray(ndim=1),
            y: ti.types.ndarray(ndim=1)):
    # y[0] = sum_i x[i]^2; the += is atomic inside Taichi's parallel top-level loop
    for i in ti.ndrange(x.shape[0]):
        y[0] += x[i] ** 2

class TaichiKernel(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        print("taichi kernel forward:", x.grad, y.grad)
        func_x2(x, y)
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        print("taichi kernel backward+1:", x.grad, grad_y)
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        y.grad = grad_y
        print("taichi kernel backward+2:", x.grad, y.grad, grad_y)
        func_x2.grad(x, y)  # Taichi reads y.grad and accumulates into x.grad
        print("taichi kernel backward+3:", x.grad, y.grad, grad_y)
        return x.grad

class TaichiModule(torch.nn.Module):
    def forward(self, x):
        # apply is called on the class; autograd Functions should not be instantiated
        return TaichiKernel.apply(x)

print("===================>")
x = torch.arange(4, dtype=torch.float32, requires_grad=True)
y = TaichiModule()(x)
loss = y.sum()
loss.backward()

print("y.grad final", y.grad)
print("x.grad final", x.grad)
```

Log/Screenshots

```
$ python my_sample_code.py
[Taichi] version 1.7.1, llvm 15.0.4, commit 0f143b2f, linux, python 3.8.8
[Taichi] Starting on arch=x64
===================>
taichi kernel forward: None None
taichi kernel backward+1: tensor([0., 0., 0., 0.]) tensor([1.])
taichi kernel backward+2: tensor([0., 0., 0., 0.]) tensor([1.]) tensor([1.])
taichi kernel backward+3: tensor([0., 2., 4., 6.]) tensor([1.]) tensor([1.])
y.grad final tensor([0.])
x.grad final tensor([ 0.,  4.,  8., 12.])
```


zjcs commented 1 month ago

The gradients of x and y in the log are wrong. The expected results are:

y.grad: logged tensor([0.]), expected tensor([1.])
x.grad: logged tensor([ 0., 4., 8., 12.]), expected tensor([0., 2., 4., 6.])
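A quick worked check of the expected x.grad (added for illustration, not part of the original report). The kernel computes y as a sum of squares, and since y has a single element, the loss y.sum() has the same gradient:

```latex
y = \sum_{i=0}^{3} x_i^2,
\qquad
\frac{\partial y}{\partial x_i} = 2x_i,
\qquad
x = (0, 1, 2, 3) \;\Rightarrow\; \nabla_x y = (0, 2, 4, 6).
```

The logged x.grad, (0, 4, 8, 12), is exactly twice this value, which points to the correct gradient being added into x.grad twice rather than to a wrong derivative.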

bobcao3 commented 1 week ago

Taichi accumulates directly into the tensor's gradient (`.grad`). For correct interop with PyTorch, you need to create new zero-initialized gradient tensors, pass them to Taichi, and then return those.

oliver-batchelor commented 1 day ago

Relates to #8339. IMO, Taichi ideally should not touch the .grad attribute at all, and should use some other attribute or method to pass gradients around.

If you are careful, you can replace the .grad tensor with zeros before calling the Taichi grad kernel, then restore whatever was in .grad afterwards, and it works.