warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
cuBLAS API failed with status 15
error detected/nfs_users/users/ali.filali/miniconda3/envs/trl/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
A: torch.Size([263, 1088]), B: torch.Size([3264, 1088]), C: (263, 3264); (lda, ldb, ldc): (c_int(8416), c_int(104448), c_int(8416)); (m, n, k): (c_int(263), c_int(3264), c_int(1088))
cuBLAS API failed with status 15
error detectedcuBLAS API failed with status 15
error detectedA: torch.Size([195, 1088]), B: torch.Size([3264, 1088]), C: (195, 3264); (lda, ldb, ldc): (c_int(6240), c_int(104448), c_int(6240)); (m, n, k): (c_int(195), c_int(3264), c_int(1088))
A: torch.Size([216, 1088]), B: torch.Size([3264, 1088]), C: (216, 3264); (lda, ldb, ldc): (c_int(6912), c_int(104448), c_int(6912)); (m, n, k): (c_int(216), c_int(3264), c_int(1088))
cuBLAS API failed with status 15
error detectedA: torch.Size([125, 1088]), B: torch.Size([3264, 1088]), C: (125, 3264); (lda, ldb, ldc): (c_int(4000), c_int(104448), c_int(4000)); (m, n, k): (c_int(125), c_int(3264), c_int(1088))
[rank3]: Traceback (most recent call last):
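The shapes in the log appear to correspond to a single 8-bit linear layer with in_features=1088 and out_features=3264. For reference, a minimal sketch that exercises the same bitsandbytes MatMul8bitLt path in isolation (this is not the actual training script; it assumes bitsandbytes and a CUDA GPU are available):

```python
# Minimal sketch of the failing path, not the original training code: build one
# bitsandbytes 8-bit linear layer with the dimensions from the log above and run
# a single forward pass through it.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(1088, 3264, bias=False, has_fp16_weights=False)
layer = layer.cuda()  # moving to the GPU triggers the int8 weight quantization

x = torch.randn(263, 1088, dtype=torch.float16, device="cuda")  # batch size taken from the log
with torch.no_grad():
    y = layer(x)
print(y.shape)  # torch.Size([263, 3264]) if the int8 matmul kernel runs successfully
```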
Expected behavior
I expect the training to start and finish in about 5 minutes, similar to what happens when I run the same code without the --load_in_8bit true flag.
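For clarity on what the flag changes (an illustrative sketch, not the actual script; the model name is a placeholder): with transformers, load_in_8bit typically means the model's linear layers are quantized through bitsandbytes at load time, while the working run loads plain fp16 weights.

```python
# Hypothetical illustration of the difference the flag makes at load time;
# the model name is a placeholder, not the one used in the failing run.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# With --load_in_8bit true: linear layers are replaced by bitsandbytes 8-bit layers,
# so every matmul goes through the MatMul8bitLt path that fails above.
model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Without the flag: ordinary fp16 weights, no bitsandbytes kernels involved.
model_fp16 = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    device_map="auto",
)
```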
System Info
- Accelerate version: 0.34.2
- accelerate bash location: ~/miniconda3/envs/trl/bin/accelerate
- Accelerate default config: Not found
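The report above does not include the GPU or CUDA details, which are usually what matters when cuBLAS returns status 15 (CUBLAS_STATUS_NOT_SUPPORTED) from the LLM.int8() kernels. A small probe to capture them, assuming PyTorch with CUDA support is installed:

```python
# Environment probe: the int8 matmul kernels depend on the GPU's compute capability
# and on the CUDA/cuBLAS build that PyTorch was compiled against.
import torch

print("torch:", torch.__version__)
print("cuda (torch build):", torch.version.cuda)
print("gpu:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))
```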
Reproduction
Running the training with the --load_in_8bit true flag gives the error shown at the top of this report.