Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

jt.scatter with reduce='add' raises an error when CUDA is enabled #485

Open renwuli opened 1 year ago

renwuli commented 1 year ago

Describe the bug

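jt.scatter(dim, index, src, reduce='add') raises an error when CUDA is enabled, but works correctly when CUDA is disabled. With int64 inputs, the generated CUDA kernel fails to compile because no atomicAdd overload matches (jittor::int64 *, jittor::int64).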

Full Log

CUDA version

[i 0803 23:25:57.342109 28 compiler.py:956] Jittor(1.3.8.5) src: /mnt/c/lirenwu/anaconda3/envs/jdev/lib/python3.7/site-packages/jittor
[i 0803 23:25:57.348883 28 compiler.py:957] g++ at /usr/bin/g++(7.5.0)
[i 0803 23:25:57.349004 28 compiler.py:958] cache_path: /home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default
[i 0803 23:25:57.355262 28 __init__.py:411] Found nvcc(10.1.105) at /usr/local/cuda-10.1/bin/nvcc.
[i 0803 23:25:57.477370 28 __init__.py:411] Found gdb(10.2) at /usr/bin/gdb.
[i 0803 23:25:57.509119 28 __init__.py:411] Found addr2line(2.34) at /usr/bin/addr2line.
[i 0803 23:25:57.830364 28 compiler.py:1011] cuda key:cu10.1.105_sm_75
[i 0803 23:25:58.520536 28 __init__.py:227] Total mem: 125.55GB, using 16 procs for compiling.
[i 0803 23:25:58.766577 28 jit_compiler.cc:28] Load cc_path: /usr/bin/g++
[i 0803 23:26:00.895891 28 init.cc:62] Found cuda archs: [75,]
[i 0803 23:26:01.213383 28 __init__.py:411] Found mpicc(4.0.3) at /usr/bin/mpicc.
[w 0803 23:26:01.534601 28 compile_extern.py:203] CUDA related path found in LD_LIBRARY_PATH or PATH(['/usr/local/cuda-10.1/lib64', '/usr/local/cuda/lib64/', '/home/leerw/.local/bin', '/usr/local/cuda-10.1/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/anaconda3/envs/jdev/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/home/yangzhipeng/Anaconda3/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/usr/local/java/latest/bin']), This path may cause jittor found the wrong libs, please unset LD_LIBRARY_PATH and remove cuda lib path in Path.
Or you can let jittor install cuda for you: `python3.x -m jittor_utils.install_cuda`
[w 0803 23:26:01.534741 28 compile_extern.py:203] CUDA related path found in LD_LIBRARY_PATH or PATH(['/usr/local/cuda-10.1/lib64', '/usr/local/cuda/lib64/', '/home/leerw/.local/bin', '/usr/local/cuda-10.1/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/anaconda3/envs/jdev/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/home/yangzhipeng/Anaconda3/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/usr/local/java/latest/bin']), This path may cause jittor found the wrong libs, please unset LD_LIBRARY_PATH and remove cuda lib path in Path.
Or you can let jittor install cuda for you: `python3.x -m jittor_utils.install_cuda`
[i 0803 23:26:04.822529 28 cuda_flags.cc:49] CUDA enabled.
/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75/jit/setitem__OP_add__Td_int64__BMASK_1__Ti_int64__IDIM_1__ODIM_1__FOV_0__VD_1__IV0_0__IO0__1_____hash_5bf10ff00f2b05db_op.cc(40): error: no instance of overloaded function "atomicAdd" matches the argument list
            argument types are: (jittor::int64 *, jittor::int64)

1 error detected in the compilation of "/tmp/tmpxft_003bde59_00000000-8_setitem__OP_add__Td_int64__BMASK_1__Ti_int64__IDIM_1__ODIM_1__FOV_0__VD_1__IV0_0__IO0__1_____hash_5bf10ff00f2b05db_op.cpp1.ii".
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    print(y)
  File "/mnt/c/lirenwu/anaconda3/envs/jdev/lib/python3.7/site-packages/jittor/__init__.py", line 2003, in vtos
    data_str = f"jt.Var({v.data}, dtype={v.dtype})"
RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.data)).

Types of your inputs are:
 self   = Var,

The function declarations are:
 inline DataView data()

Failed reason:[f 0803 23:26:07.244448 28 parallel_compiler.cc:331] Error happend during compilation:
 [Error] source file location:/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75/jit/setitem__OP_add__Td_int64__BMASK_1__Ti_int64__IDIM_1__ODIM_1__FOV_0__VD_1__IV0_0__IO0__1_____hash_5bf10ff00f2b05db_op.cc
Compile operator(5/6)failed:Op(30:0:1:1:i3:o1:s0,setitem->31)

Reason: [f 0803 23:26:06.279614 04:C0 log.cc:608] Check failed ret(256) == 0(0) Run cmd failed: "/usr/local/cuda-10.1/bin/nvcc" "/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75/jit/setitem__OP_add__Td_int64__BMASK_1__Ti_int64__IDIM_1__ODIM_1__FOV_0__VD_1__IV0_0__IO0__1_____hash_5bf10ff00f2b05db_op.cc"      -std=c++14 -Xcompiler -fPIC  -Xcompiler -march=native  -Xcompiler -fdiagnostics-color=always  -lstdc++ -ldl -shared  -I"/mnt/c/lirenwu/anaconda3/envs/jdev/lib/python3.7/site-packages/jittor/src" -I/mnt/c/lirenwu/anaconda3/envs/jdev/include/python3.7m -I/mnt/c/lirenwu/anaconda3/envs/jdev/include/python3.7m -DHAS_CUDA -DIS_CUDA -I"/usr/local/cuda-10.1/include" -I"/mnt/c/lirenwu/anaconda3/envs/jdev/lib/python3.7/site-packages/jittor/extern/cuda/inc"  -lcudart -L"/usr/local/cuda-10.1/lib64" -Xlinker -rpath="/usr/local/cuda-10.1/lib64"  -I"/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75" -L"/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75" -Xlinker -rpath="/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75" -L"/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default" -Xlinker -rpath="/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default"  -l:"jit_utils_core.cpython-37m-x86_64-linux-gnu".so  -l:"jittor_core.cpython-37m-x86_64-linux-gnu".so  -x cu --cudart=shared -ccbin="/usr/bin/g++" --use_fast_math  -w  -I"/mnt/c/lirenwu/anaconda3/envs/jdev/lib/python3.7/site-packages/jittor/extern/cuda/inc"  -arch=compute_75  -code=sm_75   -o 
"/home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default/cu10.1.105_sm_75/jit/setitem__OP_add__Td_int64__BMASK_1__Ti_int64__IDIM_1__ODIM_1__FOV_0__VD_1__IV0_0__IO0__1_____hash_5bf10ff00f2b05db_op.so"

CPU version

[i 0803 23:22:13.533711 80 compiler.py:956] Jittor(1.3.8.5) src: /mnt/c/lirenwu/anaconda3/envs/jdev/lib/python3.7/site-packages/jittor
[i 0803 23:22:13.581102 80 compiler.py:957] g++ at /usr/bin/g++(7.5.0)
[i 0803 23:22:13.581692 80 compiler.py:958] cache_path: /home/leerw/.cache/jittor/jt1.3.8/g++7.5.0/py3.7.13/Linux-5.4.0-11xdd/IntelRXeonRSilxe9/default
[i 0803 23:22:13.628909 80 __init__.py:411] Found nvcc(10.1.105) at /usr/local/cuda-10.1/bin/nvcc.
[i 0803 23:22:13.899705 80 __init__.py:411] Found gdb(10.2) at /usr/bin/gdb.
[i 0803 23:22:13.908388 80 __init__.py:411] Found addr2line(2.34) at /usr/bin/addr2line.
[i 0803 23:22:15.249193 80 compiler.py:1011] cuda key:cu10.1.105_sm_75
[i 0803 23:22:16.420193 80 __init__.py:227] Total mem: 125.55GB, using 16 procs for compiling.
[i 0803 23:22:16.627473 80 jit_compiler.cc:28] Load cc_path: /usr/bin/g++
[i 0803 23:22:17.785336 80 init.cc:62] Found cuda archs: [75,]
[i 0803 23:22:18.133422 80 __init__.py:411] Found mpicc(4.0.3) at /usr/bin/mpicc.
[w 0803 23:22:18.473905 80 compile_extern.py:203] CUDA related path found in LD_LIBRARY_PATH or PATH(['/usr/local/cuda-10.1/lib64', '/usr/local/cuda/lib64/', '/home/leerw/.local/bin', '/usr/local/cuda-10.1/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/anaconda3/envs/jdev/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/home/yangzhipeng/Anaconda3/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/usr/local/java/latest/bin']), This path may cause jittor found the wrong libs, please unset LD_LIBRARY_PATH and remove cuda lib path in Path.
Or you can let jittor install cuda for you: `python3.x -m jittor_utils.install_cuda`
[w 0803 23:22:18.474047 80 compile_extern.py:203] CUDA related path found in LD_LIBRARY_PATH or PATH(['/usr/local/cuda-10.1/lib64', '/usr/local/cuda/lib64/', '/home/leerw/.local/bin', '/usr/local/cuda-10.1/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/anaconda3/envs/jdev/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/home/yangzhipeng/Anaconda3/bin', '/home/leerw/.local/bin', '/mnt/c/lirenwu/blender-3.5.1-linux-x64', '/mnt/c/lirenwu/nvim/bin', '/mnt/c/lirenwu/anaconda3/bin', '/mnt/c/lirenwu/anaconda3/condabin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/home/yangzhipeng/Anaconda3/bin', '/usr/local/sbin', '/usr/local/bin', '/usr/sbin', '/usr/bin', '/sbin', '/bin', '/usr/games', '/usr/local/games', '/snap/bin', '/usr/local/java/latest/bin', '/usr/local/java/latest/bin']), This path may cause jittor found the wrong libs, please unset LD_LIBRARY_PATH and remove cuda lib path in Path.
Or you can let jittor install cuda for you: `python3.x -m jittor_utils.install_cuda`

Compiling Operators(6/6) used: 3.35s eta:    0s
jt.Var([50 50], dtype=int64)
jt.Var([50 50], dtype=int64)

Minimal Reproduction

import jittor as jt

jt.flags.use_cuda = 1  # set to 0 to get the working CPU result

dim = 0
x = jt.zeros((2,)).int64()
src = jt.ones((100,)).int64()
index = jt.zeros((100,)).int64()
index[50:] = 1  # first 50 entries go to slot 0, last 50 to slot 1

y = x.scatter(dim, index, src, reduce='add')  # out-of-place
print(y)
x.scatter_(dim, index, src, reduce='add')  # in-place
print(x)

Expected behavior

jt.Var([50 50], dtype=int64)
jt.Var([50 50], dtype=int64)
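
For reference, here is a small pure-Python sketch of what I expect scatter with reduce='add' to compute in the 1-D, dim=0 case. It assumes the usual scatter-add semantics (out[index[i]] += src[i]); the helper name scatter_add_1d is only illustrative and is not part of Jittor.

```python
def scatter_add_1d(x, index, src):
    """Return a copy of x with src[i] added into position index[i]."""
    out = list(x)  # copy, so the input is left untouched (out-of-place)
    for i, s in zip(index, src):
        out[i] += s
    return out


# Same data as the reproduction script, using plain Python lists.
x = [0, 0]
src = [1] * 100
index = [0] * 50 + [1] * 50  # first 50 entries -> slot 0, last 50 -> slot 1

y = scatter_add_1d(x, index, src)
print(y)  # [50, 50]
```

This matches the CPU output above, so the CUDA path should give the same values.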

LDYang694 commented 5 months ago

We recommend running with a newer CUDA driver and CUDA toolkit, e.g. CUDA 11.2 + cuDNN 8.