Why is torch.bfloat16 not included in `_CUTLASS_TYPE_TO_TORCH_TYPE`?
```python
_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}
```
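If the entry is simply missing, would a patch along these lines be enough? (Untested sketch on my side: it assumes `DataType.bf16` exists in the Python `DataType` enum, that the underlying grouped GEMM kernel handles bf16 at all, and that `torch::kBFloat16`, PyTorch's C++ dtype constant for torch.bfloat16, is valid wherever the emitter uses this string.)

```python
# Untested sketch: register a bf16 entry in the emitter's type map
# before calling cutlass.emit.pytorch.
import importlib

from cutlass import DataType

# Load the emitter module explicitly, since the cutlass.emit.pytorch
# attribute is the emit function itself rather than the module.
pytorch_emitter = importlib.import_module("cutlass.emit.pytorch")
pytorch_emitter._CUTLASS_TYPE_TO_TORCH_TYPE[DataType.bf16] = "torch::kBFloat16"
```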
**Describe the bug**
I followed 02_pytorch_extension_grouped_gemm.ipynb and changed the dtype from torch.float16 to torch.bfloat16. This raises an error.
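Roughly, the failing code path looks like this (following the notebook; the dtype line is my only change):

```python
# Repro sketch based on 02_pytorch_extension_grouped_gemm.ipynb, with the
# dtype switched from torch.float16 to torch.bfloat16.
import torch
import cutlass

dtype = torch.bfloat16  # was torch.float16 in the notebook

plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)
op = plan.construct()

# Fails here: DataType.bf16 has no entry in _CUTLASS_TYPE_TO_TORCH_TYPE.
grouped_gemm = cutlass.emit.pytorch(
    op, name="grouped_gemm", cc=plan.cc, sourcedir="out", jit=True
)
```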
Is torch.bfloat16 not supported?