NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Why _CUTLASS_TYPE_TO_TORCH_TYPE doesn't support torch.bfloat16? #1736

Open hxdtest opened 2 months ago

hxdtest commented 2 months ago

What is your question? In python/cutlass/emit/pytorch.py, why is bfloat16 not supported?

_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}
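(For reference, supporting bfloat16 would presumably come down to one additional entry in this table. A minimal sketch, assuming the cutlass Python package exposes DataType.bf16 and that torch::kBFloat16 is the matching LibTorch scalar-type constant; this is not the actual implementation:)

# Sketch only: extend the emitter's type mapping with a bfloat16 entry.
# Assumes DataType.bf16 exists in the cutlass Python DataType enum and that
# torch::kBFloat16 is the LibTorch constant the generated C++ should reference.
from cutlass import DataType  # assumed import path for the DataType enum

_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.bf16: "torch::kBFloat16",  # hypothetical new entry
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}

(The emitted string must name a scalar-type constant that the targeted LibTorch version actually defines, so the exact spelling of the bfloat16 constant would need to be verified against that version.)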
jackkosaian commented 2 months ago

We simply haven't implemented it. We welcome contributions in this space.

Bogumil-Sapinski-Mobica commented 1 month ago

Hi @jackkosaian, could I take care of it? Could you assign me?

jackkosaian commented 1 month ago

Yes, please feel free to submit a PR supporting this.

Bogumil-Sapinski-Mobica commented 1 month ago

Hi, I have added PR https://github.com/NVIDIA/cutlass/pull/1843. I added it because:

However, I have some doubts:

I was not able to run 02_pytorch_extension_grouped_gemm.ipynb with my changes, so I have no proof that it works, at least in this example. Any advice on how to import the Python package from my local repo would be welcome.
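(One way to confirm that the notebook is picking up local changes rather than an installed wheel is sketched below. It assumes the local Python package was installed in editable mode, e.g. with pip install -e from the directory containing its setup.py, and that the Jupyter kernel was restarted afterwards.)

# Sketch: sanity checks to run in a notebook cell, verifying that `cutlass` is
# imported from the local checkout and that the modified mapping is the one loaded.
import cutlass
from cutlass import DataType
from cutlass.emit.pytorch import _CUTLASS_TYPE_TO_TORCH_TYPE

print(cutlass.__file__)  # expected to point into the local repository, not site-packages
print(DataType.bf16 in _CUTLASS_TYPE_TO_TORCH_TYPE)  # True once the bf16 entry is in effect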