shader-slang / slang-python

Superseded by github.com/shader-slang/slang-torch
MIT License

how to use DiffTensorView<float> in a struct #18

Closed: brabbitdousha closed this issue 4 months ago

brabbitdousha commented 4 months ago

Hi, here are my code and script. I can't get gradients to flow when using a DiffTensorView inside a struct:

struct Reservoir
{
    DiffTensorView<float> input0;
    DiffTensorView<float> input1;
};

[AutoPyBindCUDA]
[CUDAKernel]
[Differentiable]
void square(Reservoir input, DiffTensorView output)
{
    uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim();
    if (dispatchIdx.x >= output.size(0) || dispatchIdx.y >= output.size(1))
        return;

    // output = input0^2 + input1^2, elementwise
    output[dispatchIdx.x, dispatchIdx.y] = input.input0[dispatchIdx.x, dispatchIdx.y] * input.input0[dispatchIdx.x, dispatchIdx.y];
    output[dispatchIdx.x, dispatchIdx.y] += input.input1[dispatchIdx.x, dispatchIdx.y] * input.input1[dispatchIdx.x, dispatchIdx.y];
}
And the Python script:

import torch
import slangpy

m = slangpy.loadModule("D:/codes/python/slang_test/test_struct/square.slang")

class MySquareFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):

        input['input0'] = input['input0'].contiguous()
        input['input1'] = input['input1'].contiguous()
        output = torch.zeros_like(input['input0'])
        output = output.contiguous()

        kernel_with_args = m.square(input=input, output=output)
        kernel_with_args.launchRaw(
            blockSize=(32, 32, 1),
            gridSize=((output.shape[0] + 31) // 32, (output.shape[1] + 31) // 32, 1))

        ctx.save_for_backward(input, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        (input, output) = ctx.saved_tensors

        input_grad = torch.zeros_like(input)
        grad_output = grad_output.contiguous()

        # Note: When using DiffTensorView, grad_output gets 'consumed' during the reverse-mode pass.
        # If grad_output may be reused, consider calling grad_output = grad_output.clone()
        kernel_with_args = m.square.bwd(input=(input, input_grad), output=(output, grad_output))
        kernel_with_args.launchRaw(
            blockSize=(32, 32, 1),
            gridSize=((output.shape[0] + 31) // 32, (output.shape[1] + 31) // 32, 1))

        return input_grad

x = torch.tensor([[3.0, 4.0],[0.0, 1.0]], requires_grad=True, device='cuda')
y = torch.tensor([[5.0, 6.0],[7.0, 0.0]], requires_grad=True, device='cuda')
#print(f"X = {x}")
input = {'input0': x, 'input1': y}
y_pred = MySquareFunc.apply(input)
loss = y_pred.sum()
loss.backward()
print(f"dX = {x.grad.cpu()}")
print(f"dy = {y.grad.cpu()}")

The error "element 0 of tensors does not require grad and does not have a grad_fn" occurs at loss.backward(). I tried adding IDifferentiable to the struct, but then it doesn't compile.

brabbitdousha commented 4 months ago

Sorry, it turns out there was something wrong with the dict.
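
For reference, the error comes from the PyTorch side rather than from Slang: torch.autograd.Function only tracks gradients for tensors passed directly as arguments to apply(), so wrapping x and y in a dict hides them from autograd and y_pred never gets a grad_fn (ctx.save_for_backward likewise only accepts tensors, not a dict). Below is a minimal sketch of a workaround, assuming the tensors are passed to apply() individually and the dict the kernel expects is rebuilt inside the function; the per-field (tensor, grad) pairing in the bwd call is a guess at the packing for struct members and may need adjusting.

class MySquareFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1):
        # Pass tensors directly so autograd can track them, then
        # rebuild the struct dict the kernel expects.
        input0 = input0.contiguous()
        input1 = input1.contiguous()
        output = torch.zeros_like(input0)

        m.square(input={'input0': input0, 'input1': input1}, output=output).launchRaw(
            blockSize=(32, 32, 1),
            gridSize=((output.shape[0] + 31) // 32, (output.shape[1] + 31) // 32, 1))

        ctx.save_for_backward(input0, input1, output)  # tensors only, no dict
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input0, input1, output = ctx.saved_tensors
        grad_input0 = torch.zeros_like(input0)
        grad_input1 = torch.zeros_like(input1)
        grad_output = grad_output.contiguous()

        # Assumption: each struct field takes its own (primal, grad) pair.
        m.square.bwd(
            input={'input0': (input0, grad_input0), 'input1': (input1, grad_input1)},
            output=(output, grad_output)).launchRaw(
            blockSize=(32, 32, 1),
            gridSize=((output.shape[0] + 31) // 32, (output.shape[1] + 31) // 32, 1))

        # Return one gradient per forward() argument (ctx excluded).
        return grad_input0, grad_input1

y_pred = MySquareFunc.apply(x, y)

With this shape, x.grad and y.grad are populated after loss.backward(), because both tensors are visible to autograd at the apply() boundary.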