NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

`02_pytorch_extension_grouped_gemm.ipynb` No kernel configuration found for supported data type and layout combination (<DataType.bf16: 16> #1757

Open hxdtest opened 2 weeks ago

hxdtest commented 2 weeks ago

Describe the bug
I followed 02_pytorch_extension_grouped_gemm.ipynb and changed the dtype from torch.float16 to torch.bfloat16:

import cutlass
import torch

dtype = torch.bfloat16
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)

It raises the following error:

  File "/opt/conda/lib/python3.10/site-packages/cutlass/op/gemm.py", line 300, in _reset_operations
    raise Exception(f'No kernel configuration found for supported data type and layout '
Exception: No kernel configuration found for supported data type and layout combination (<DataType.bf16: 16>, <DataType.bf16: 16>, <DataType.bf16: 16>)x(<LayoutType.RowMajor: 2>, <LayoutType.RowMajor: 2>)
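
For reference, the notebook's original sequence is identical except for the dtype (a minimal sketch of the unmodified tutorial path; the final commented lines describe how the tutorial then exercises the emitted extension, reproduced from memory and not re-verified here):

import cutlass
import torch

# Original dtype used in 02_pytorch_extension_grouped_gemm.ipynb
dtype = torch.float16
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)

# The notebook then builds per-group tensor lists As and Bs and runs the
# emitted extension, e.g. Ds = grouped_gemm.run(As, Bs)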

Is torch.bfloat16 not supported?

hxdtest commented 2 weeks ago

Why is torch.bfloat16 not included in _CUTLASS_TYPE_TO_TORCH_TYPE?

_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}
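
A possible local experiment (an untested sketch, not an official fix) is to register a bf16 entry in that table before emitting. The module path cutlass.emit.pytorch and the torch::kBFloat16 constant are assumptions here:

import importlib
import cutlass

# Assumption: the mapping quoted above lives in the cutlass.emit.pytorch emitter module.
pytorch_emitter = importlib.import_module("cutlass.emit.pytorch")
# Assumption: torch::kBFloat16 is the matching libtorch scalar-type constant.
pytorch_emitter._CUTLASS_TYPE_TO_TORCH_TYPE[cutlass.DataType.bf16] = "torch::kBFloat16"

Note, though, that the exception in the original report is raised earlier, during cutlass.op.GroupedGemm construction (gemm.py, _reset_operations), so extending the emitter's table alone would not resolve it.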
jackkosaian commented 2 weeks ago

Please see https://github.com/NVIDIA/cutlass/issues/1736#issuecomment-2305319679