NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] Conv2D Fprop Gives Wrong Output #1690

Open YixuanSeanZhou opened 3 months ago

YixuanSeanZhou commented 3 months ago

Describe the bug

Conv2D Fprop gives the wrong output when computing a small convolution in FP32.

Steps/Code to reproduce bug

Simply run:

import torch
import numpy as np
import cutlass

inputs = torch.from_numpy(np.array([[
    [[1 ,1], [1, 1]],
    [[2 ,2], [2, 2]],
    [[3 ,3], [3, 3]],
    [[0 ,0], [0, 0]],
]]))

inputs = inputs.to(torch.float32).to(memory_format=torch.channels_last)

weights = torch.from_numpy(np.array([[[[1]], [[1]], [[1]], [[1]]]]))
weights = weights.to(torch.float32).to(memory_format=torch.channels_last)

N, C, H, W = inputs.shape
K, C, R, S = weights.shape

# Not defined in the original report. The values below are assumptions inferred
# from the expected output (a 1x1 filter giving a 2x2 output equal to the channel sum).
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
alpha = 1.0
beta = 0.0
print_module = False

N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)


dtype = torch.float32

plan = cutlass.Conv2dFprop(element_input=dtype, element_weight=dtype, element_C=dtype, element_output=dtype, element_accumulator=dtype)

tensor_C = torch.empty(size=(N, K, P, Q), dtype=dtype, device="cuda").to(memory_format=torch.channels_last)
output = torch.zeros_like(tensor_C)
plan.run(inputs, weights, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)

print(output)

Actual output:
tensor([[[[4., 6.],
          [6., 6.]]]], device='cuda:0')

Expected behavior

Output should be:

tensor([[[[6., 6.],
          [6., 6.]]]], device='cuda:0')
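
For reference, the expected values can be checked without CUTLASS or a GPU. The following is a minimal pure-Python sketch of a direct forward convolution, not the CUTLASS implementation. The helper names `conv2d_output_size` and `conv2d_fprop_reference` are hypothetical, and the defaults (stride 1, padding 0, dilation 1) are assumptions, since the report does not define those variables.

```python
def conv2d_output_size(H, W, R, S, padding, stride, dilation):
    """Standard convolution output extent (P, Q) for the two spatial dims."""
    P = (H + 2 * padding[0] - dilation[0] * (R - 1) - 1) // stride[0] + 1
    Q = (W + 2 * padding[1] - dilation[1] * (S - 1) - 1) // stride[1] + 1
    return P, Q


def conv2d_fprop_reference(inp, wgt, padding=(0, 0), stride=(1, 1), dilation=(1, 1)):
    """Direct convolution (cross-correlation) on nested lists.

    inp is indexed [n][c][h][w] and wgt is indexed [k][c][r][s], matching the
    logical NCHW / KCRS shapes used in the report.
    """
    N, C, H, W = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    K, R, S = len(wgt), len(wgt[0][0]), len(wgt[0][0][0])
    P, Q = conv2d_output_size(H, W, R, S, padding, stride, dilation)
    out = [[[[0.0] * Q for _ in range(P)] for _ in range(K)] for _ in range(N)]
    for n in range(N):
        for k in range(K):
            for p in range(P):
                for q in range(Q):
                    acc = 0.0
                    for c in range(C):
                        for r in range(R):
                            for s in range(S):
                                h = p * stride[0] - padding[0] + r * dilation[0]
                                v = q * stride[1] - padding[1] + s * dilation[1]
                                # Positions that fall in the zero padding add nothing.
                                if 0 <= h < H and 0 <= v < W:
                                    acc += inp[n][c][h][v] * wgt[k][c][r][s]
                    out[n][k][p][q] = acc
    return out


# Inputs from the report: four 2x2 channels filled with 1, 2, 3, 0.
inp = [[[[float(v)] * 2 for _ in range(2)] for v in (1, 2, 3, 0)]]
# One filter with four 1x1 taps, all equal to 1.
wgt = [[[[1.0]] for _ in range(4)]]

expected = conv2d_fprop_reference(inp, wgt)
print(expected)  # [[[[6.0, 6.0], [6.0, 6.0]]]]
```

With a 1x1 all-ones filter, every output element is the sum over the four input channels, 1 + 2 + 3 + 0 = 6, so a 4 in the top-left position cannot come from the convolution arithmetic itself under these assumed parameters.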

github-actions[bot] commented 2 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.