aredden / torch-cublas-hgemm

PyTorch half-precision GEMM library with fused optional bias and optional ReLU/GELU

Matmul errors out when one tensor is batched and another isn't #1

Open rationalism opened 5 months ago

rationalism commented 5 months ago

Cool idea! Proud to submit a first bug report :)

This PyTorch code (Ubuntu, CUDA 12.1, PyTorch 2.2.2, NVIDIA RTX 4090):

>>> import cublas_ops
>>> import torch
>>> x = torch.ones([1, 2560, 8192], dtype=torch.float16, device="cuda:0")
>>> x
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]], device='cuda:0',
       dtype=torch.float16)
>>> y = torch.ones([8192, 28672], dtype=torch.float16, device="cuda:0")
>>> z = cublas_ops.cublas_half_matmul_batched_simple(x, y)

fails with this stack trace:

 ** On entry to HgemmStridedBatched parameter number 10 had an illegal value
cuBLAS API failed with status 7
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/cublas_ops/__init__.py", line 39, in cublas_half_matmul_batched_simple
    return _cublas_hgemm_batched_simple(a, b)
RuntimeError: cuBLAS API failed

but this code works:

>>> y = torch.ones([1, 8192, 28672], dtype=torch.float16, device="cuda:0")
>>> z = cublas_ops.cublas_half_matmul_batched_simple(x, y)
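Until the 2-D case is handled, one possible workaround is sketched below. It relies on a simple identity: when every batch shares the same 2-D weight `y` of shape (K, N), the product (B, M, K) @ (K, N) equals a single (B*M, K) @ (K, N) product. The sketch therefore folds the batch into the rows of `x` and gives `y` a leading batch dimension of 1, matching the shapes in the working example above. `matmul_2d_weight` is a hypothetical helper, not part of `cublas_ops`. It assumes the extension accepts a contiguous `unsqueeze` view the same way it accepts the freshly allocated 3-D tensor above; that has not been verified.

```python
import torch


def matmul_2d_weight(batched_matmul, x, y):
    """Hypothetical helper: multiply x of shape (B, M, K) by a 2-D weight y
    of shape (K, N) using a batched matmul that expects two 3-D inputs.

    `batched_matmul` is any callable taking (a, b) with shapes (B', M', K)
    and (B', K, N). Passing cublas_ops.cublas_half_matmul_batched_simple
    here is an assumption based on the working call in this issue.
    """
    if x.dim() != 3 or y.dim() != 2:
        raise ValueError("expected x with 3 dims and y with 2 dims")
    b, m, k = x.shape
    if y.shape[0] != k:
        raise ValueError(f"inner dims differ: {k} vs {y.shape[0]}")
    # All batches share y, so fold the batch into the rows of x.
    # reshape is a view when x is contiguous (it copies otherwise).
    x_folded = x.reshape(1, b * m, k)
    # unsqueeze returns a view, so the (possibly large) weight is not copied.
    y_batched = y.unsqueeze(0)
    out = batched_matmul(x_folded, y_batched)  # shape (1, B*M, N)
    # Restore the original batch layout: (B, M, N).
    return out.reshape(b, m, y.shape[1])
```

With the tensors from the failing example, the call would be `z = matmul_2d_weight(cublas_ops.cublas_half_matmul_batched_simple, x, y)`. Folding the batch avoids the alternative of expanding `y` to B copies. For a 8192 x 28672 fp16 weight, each extra copy would cost roughly 470 MB.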
aredden commented 5 months ago

Ah, I'll look into it, thanks!