shader-slang / slang-python

Superseded by github.com/shader-slang/slang-torch
MIT License

Using [Differentiable] without DiffTensorView #12

Open bprb opened 7 months ago

bprb commented 7 months ago

Hello!

By mistake, I added [Differentiable] to the square sample, but I didn't change TensorView to DiffTensorView.

The result was a square.fwd() that simply calculates square(). I was wondering whether there is ever a valid use case for this and, if not, whether the compiler could catch the mistake. Otherwise, if I also mistakenly pass a single input and a single output to fwd, there are no errors, but the result is definitely not the forward-mode derivative :)

In detail, this shader...

[AutoPyBindCUDA]
[CUDAKernel]
[Differentiable]
void square(TensorView<float> input, TensorView<float> output)
{
    // ...
    output[dispatchIdx.x] = input[dispatchIdx.x] * input[dispatchIdx.x];
}

... compiles without error, but the generated fwd does not compute the derivative; it just recomputes the square:

__device__ void s_fwd_square_0(TensorView input_2, TensorView output_2)
{
    uint _S14 = (((blockIdx)) * ((blockDim)) + ((threadIdx))).x;
    uint _S15 = ((input_2).sizes[(0U)]);
    if(_S14 >= _S15)
    {
        return;
    }
    float _S16 = ((input_2).load<float>((_S14)));
    float _S17 = ((input_2).load<float>((_S14)));
    (output_2).store<float>((_S14), (_S16 * _S17));   // ?!
    return;
}

extern "C" {
__global__ void __kernel__square_fwd_diff(TensorView _S18, TensorView _S19)
{
    s_fwd_square_0(_S18, _S19);
    return;
}

So with this Python code, forward ends up holding the same values as squared:

import torch
import slangpy

# Assumes the TensorView version of the shader above is saved as square.slang
m = slangpy.loadModule("square.slang")

inputs = torch.tensor((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), dtype=torch.float).cuda()

squared = torch.zeros_like(inputs)   # *_like keeps the CUDA device of inputs
forward = torch.ones_like(inputs)

m.square(input=inputs, output=squared).launchRaw(blockSize=(32, 1, 1), gridSize=(64, 1, 1))
m.square.fwd(input=inputs, output=forward).launchRaw(blockSize=(32, 1, 1), gridSize=(64, 1, 1))

print(inputs)
print(squared)
print(forward)

Output:

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], device='cuda:0')
tensor([  1.,   4.,   9.,  16.,  25.,  36.,  49.,  64.,  81., 100.], device='cuda:0')
tensor([  1.,   4.,   9.,  16.,  25.,  36.,  49.,  64.,  81., 100.], device='cuda:0')    # oh noes
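One way to notice this kind of silent mistake (a suggestion, not something slang-python provides) is to compare the fwd result against a central finite difference of the primal function. The sketch below is plain Python; it assumes the derivative passed into fwd was 1.0 for every element, and central_difference and check_forward are hypothetical helper names. CUDA tensors can be copied to host lists with .tolist() before calling it.

```python
def square(x: float) -> float:
    """Primal function computed by the shader: y = x * x."""
    return x * x


def central_difference(f, x: float, h: float = 1e-3) -> float:
    """Estimate f'(x) with a central finite difference."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def check_forward(inputs, forward, rtol=1e-3):
    """Return the indices where a forward-mode result disagrees with finite differences.

    Assumes the input derivative fed to fwd was 1.0 everywhere, so the
    expected value at each index is simply square'(x).
    """
    bad = []
    for i, (x, dy) in enumerate(zip(inputs, forward)):
        expected = central_difference(square, x)
        if abs(dy - expected) > rtol * max(1.0, abs(expected)):
            bad.append(i)
    return bad


inputs = [float(v) for v in range(1, 11)]
mistaken_forward = [x * x for x in inputs]   # what this issue observed
correct_forward = [2.0 * x for x in inputs]  # d(x^2)/dx with input derivative 1

print(check_forward(inputs, mistaken_forward))  # [0, 2, 3, 4, 5, 6, 7, 8, 9]
print(check_forward(inputs, correct_forward))   # []
```

With the mistaken output, every index except x = 2 is flagged, because there x * x happens to equal 2 * x. That coincidence is also a reminder to check more than one input value.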

Everything is fine with the correct signature in the shader...

[AutoPyBindCUDA]
[CUDAKernel]
[Differentiable]
void square(DiffTensorView<float> input, DiffTensorView<float> output)
{
    uint3 dispatchIdx = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
    if (dispatchIdx.x >= input.size(0))
        return;

    output[dispatchIdx.x] = input[dispatchIdx.x] * input[dispatchIdx.x];
}

... and called as ...

inputs = torch.tensor((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), dtype=torch.float).cuda()
input_grad = torch.ones_like(inputs)   # derivative of each input element

squared = torch.zeros_like(inputs)
forward = torch.ones_like(inputs)      # receives the derivative of the output

m.square(input=inputs, output=squared).launchRaw(blockSize=(32, 1, 1), gridSize=(64, 1, 1))
# Each DiffTensorView argument of fwd takes a (value, derivative) pair
m.square.fwd(input=(inputs, input_grad), output=(squared, forward)).launchRaw(blockSize=(32, 1, 1), gridSize=(64, 1, 1))
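With input_grad set to ones, the chain rule says forward should hold d(x²) = 2·x·dx = 2x. The plain-Python model below mirrors the (value, derivative) pairs passed to m.square.fwd and shows the values one would expect in squared and forward. square_fwd is a hypothetical name, and this is a model of the math, not slang-python's implementation.

```python
from typing import List, Tuple

Pair = Tuple[List[float], List[float]]  # (values, derivatives)


def square_fwd(input_pair: Pair) -> Pair:
    """Model of forward-mode differentiation of y = x * x, element by element.

    Takes (x, dx) and returns (y, dy) with dy = 2 * x * dx (chain rule),
    mirroring input=(inputs, input_grad) and output=(squared, forward).
    """
    xs, dxs = input_pair
    ys = [x * x for x in xs]
    dys = [2.0 * x * dx for x, dx in zip(xs, dxs)]
    return ys, dys


inputs = [float(v) for v in range(1, 11)]
input_grad = [1.0] * len(inputs)

squared, forward = square_fwd((inputs, input_grad))
print(squared)  # [1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0, 100.0]
print(forward)  # [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0]
```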

So I was wondering whether this user error could be caught. It might also help to add a fwd example to the docs that makes it very clear that, just like bwd, it expects (value, derivative) pairs.

Thanks! bert

(Edit: rewrote for clarity)

saipraveenb25 commented 7 months ago

Thank you for bringing it to our attention!

This is unfortunately a tricky situation. Since we do want to allow regular TensorView types for passing non-differentiable data, it would be confusing to throw errors/warnings when using TensorView with a [Differentiable] function.

As you suggested, for now we'll add a .fwd() example to the documentation to make the interface clearer, i.e., that .fwd() computes both the regular output and its derivative.
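A minimal sketch of why TensorView inside a [Differentiable] function cannot simply be rejected: forward-mode differentiation treats non-differentiable inputs as constants with zero derivative. That is exactly what you want for a parameter such as a scale factor, and exactly what silently happened in this issue when every argument was a TensorView. The model below uses dual numbers in plain Python; Dual, as_dual and scaled_square_fwd are hypothetical names, and plain floats stand in for TensorView data.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Dual:
    """A value paired with its forward-mode derivative."""
    primal: float
    tangent: float


# A plain float plays the role of non-differentiable (TensorView-like) data.
Value = Union[Dual, float]


def as_dual(v: Value) -> Dual:
    # Non-differentiable inputs are treated as constants: derivative zero.
    return v if isinstance(v, Dual) else Dual(float(v), 0.0)


def scaled_square_fwd(x: Value, scale: Value) -> Dual:
    """Forward-mode derivative of y = scale * x * x (hypothetical example function)."""
    x, scale = as_dual(x), as_dual(scale)
    primal = scale.primal * x.primal * x.primal
    # Product rule: d(s * x * x) = ds * x * x + s * 2 * x * dx
    tangent = (scale.tangent * x.primal * x.primal
               + scale.primal * 2.0 * x.primal * x.tangent)
    return Dual(primal, tangent)


# Intended use: x is differentiable, scale is plain data.
print(scaled_square_fwd(Dual(3.0, 1.0), 2.0))  # Dual(primal=18.0, tangent=12.0)

# The mistake from this issue: nothing is differentiable, so the derivative is 0.
print(scaled_square_fwd(3.0, 2.0))             # Dual(primal=18.0, tangent=0.0)
```

In the shader from this issue, the output is also a TensorView with no derivative slot at all, so the generated fwd writes only the primal value (x * x), which matches the output shown earlier in the thread.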

bprb commented 7 months ago

Hi Sai, thank you for the quick reply! I see, so using TensorView in a [Differentiable] function is a valid thing to do. We'll just have to be careful then :) Thanks for taking the suggestion about the docs!