Open Purewhite2019 opened 1 year ago
**Describe the bug**

`RuntimeError` occurs when calling `jt.matmul` in WSL on an NVIDIA GeForce MX450 laptop.
**Full Log**

```
[i 1031 23:09:50.689249 80 cuda_flags.cc:32] CUDA enabled.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-b8f9d3eceb14> in <module>
----> 1 jt.flags.use_cuda = 1; jt.matmul(jt.rand((1, 1, 100, 100)), jt.rand(((1, 16, 100, 1)))).shape

~/anaconda3/envs/gm-jittor/lib/python3.7/site-packages/jittor/nn.py in matmul(a, b)
    122     # a: [..., n, m], b: [..., m, k], c:[..., n, k]
    123     if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
--> 124         return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
    125     shape = []
    126     len_c = max(len_a, len_b)

RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.cublas_batched_matmul)).

Types of your inputs are:
 self   = module,
 args   = (Var, Var, int, int, ),

The function declarations are:
 VarHolder* cublas_batched_matmul(VarHolder* a, VarHolder* b, bool trans_a, bool trans_b)

Failed reason:[f 1031 23:09:50.690284 80 cublas_batched_matmul_op.cc:75] Check failed a->shape[i](1) == b->shape[i](16) Something wrong ... Could you please report this issue?
```
**Minimal Reproduce**

```python
# On CPU, jt.matmul() works
jt.flags.use_cuda = 0; jt.matmul(jt.rand((1, 1, 100, 100)), jt.rand((1, 16, 100, 1))).shape

# On GPU, it doesn't work
jt.flags.use_cuda = 1; jt.matmul(jt.rand((1, 1, 100, 100)), jt.rand((1, 16, 100, 1))).shape
```
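For reference, the expected behavior for the batch dimensions is ordinary broadcasting: a size-1 batch dimension in one operand is stretched to match the other operand's batch size. A minimal NumPy sketch (NumPy stands in for Jittor here; no CUDA involved) of the shapes in the reproduce case:

```python
import numpy as np

# Same shapes as the failing call: batch dims (1, 1) vs (1, 16)
a = np.random.rand(1, 1, 100, 100).astype(np.float32)
b = np.random.rand(1, 16, 100, 1).astype(np.float32)

# np.matmul broadcasts the leading (batch) dimensions, so the
# size-1 second batch dim of `a` is broadcast to 16.
c = np.matmul(a, b)
print(c.shape)  # (1, 16, 100, 1)
```

This is the shape the CPU path above produces; the cuBLAS path instead rejects the mismatched batch dimension outright.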
**Possible Solution**

```python
# Add repeat(1, 16, 1, 1) to explicitly expand `a` to the broadcast shape
jt.flags.use_cuda = 1; jt.matmul(jt.rand((1, 1, 100, 100)).repeat(1, 16, 1, 1), jt.rand((1, 16, 100, 1))).shape
```
However, this workaround results in a large computational error on GPU:
```python
import numpy as np
import torch
import jittor as jt

a = torch.randn(1, 1, 100, 100).numpy()
b = torch.randn(1, 16, 100, 1).numpy()
c_torch = torch.matmul(torch.tensor(a), torch.tensor(b)).numpy()

c_jt = jt.matmul(jt.Var(a).repeat(1, 16, 1, 1), jt.Var(b)).numpy()
np.testing.assert_almost_equal(c_torch, c_jt)  # Passed

jt.flags.use_cuda = 1
c_jt = jt.matmul(jt.Var(a).repeat(1, 16, 1, 1), jt.Var(b)).numpy()
np.testing.assert_almost_equal(c_torch, c_jt)  # AssertionError
```
```
AssertionError:
Arrays are not almost equal to 7 decimals

Mismatched elements: 1350 / 1600 (84.4%)
Max absolute difference: 1.1444092e-05
Max relative difference: 3.5474255e-05
 x: array([[[[ 4.539161 ],
         [ 5.9756746],
         [ 13.752586 ],...
 y: array([[[[ 4.53916  ],
         [ 5.9756722],
         [ 13.752583 ],...
```
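Note that the mismatch reported here (max absolute difference ≈ 1.1e-5) is on the scale of float32 rounding for a 100-term dot product, and different summation orders on CPU and GPU can legitimately produce differences of this size. A NumPy-only sketch (no Jittor, PyTorch, or CUDA; shapes chosen to mirror one batch slice of the report) showing that two float32 evaluation orders both agree with a float64 reference only up to a tolerance:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((100, 100)).astype(np.float32)
b = rng.standard_normal((100, 1)).astype(np.float32)

# Reference computed in float64, then rounded back to float32.
ref = (a.astype(np.float64) @ b.astype(np.float64)).astype(np.float32)

# Two float32 evaluations that may accumulate rounding differently:
c1 = a @ b                                      # library matmul
c2 = (a[:, :, None] * b[None, :, :]).sum(axis=1)  # explicit multiply-and-sum

# Both match the reference within float32 rounding, even if they are
# not bitwise identical to each other.
np.testing.assert_allclose(c1, ref, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(c2, ref, rtol=1e-4, atol=1e-5)
```

Under this view, a comparison such as `np.testing.assert_allclose(c_torch, c_jt, rtol=1e-4)` would likely pass where `assert_almost_equal` (7 decimals) fails, though the shape error itself is still a real bug.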
**Expected behavior**

`jt.matmul` should work correctly on CUDA.
Thank you for the report; we will fix the shape issue in an upcoming update.