Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0
3.07k stars, 307 forks

Error when running JDet/run_net.py: error: namespace "thrust" has no member "sequence"/"unique"/"sort" #568

Closed. HandSomeGuy001 closed this issue 1 month ago.

HandSomeGuy001 commented 1 month ago

Describe the bug

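I get an error when training an object detection model with JDet. Judging from the error output, I think error: namespace "thrust" has no member "sequence"/"unique"/"sort" is the key reason training fails, but I don't know how to fix it. Running python -m jittor.test.test_cuda works fine.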
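For context: the log below shows Jittor's unique (called from JDet's sampler.py) being compiled from kernels that call thrust::sequence, thrust::sort, and thrust::unique. A plausible reading, assumed here rather than taken from Jittor's source, is the classic sort-based unique: build an index array, sort it by value, then collapse equal neighbours. The helper name unique_with_inverse is hypothetical; it is a small CPU reference for what those three steps compute, not Jittor's implementation, and it makes no claim about Jittor's exact output ordering.

```python
from typing import List, Tuple


def unique_with_inverse(values: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical CPU reference for a sort-based unique.

    Returns the sorted distinct values and, for each input element,
    the position of its value in that output (the "inverse" indices).
    """
    n = len(values)
    # Step 1 (like thrust::sequence): indices 0, 1, ..., n-1
    order = list(range(n))
    # Step 2 (like thrust::sort): order the indices by the values they point to
    order.sort(key=lambda i: values[i])
    output: List[int] = []
    inverse = [0] * n
    for i in order:
        # Step 3 (like thrust::unique): keep a value only if it differs
        # from the last value kept, so duplicates collapse to one entry
        if not output or values[i] != output[-1]:
            output.append(values[i])
        # Every element maps to the slot holding its deduplicated value
        inverse[i] = len(output) - 1
    return output, inverse
```

For example, unique_with_inverse([3, 1, 3, 2]) gives ([1, 2, 3], [2, 0, 2, 1]). This can be handy for sanity-checking, on small inputs, what the .unique() call in sampler.py is expected to return.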

My environment:

Full Log

(base) root@intern-studio-50088800:~/JDet# python tools/run_net.py --config-file=./configs/s2anet/s2anet_r50_fpn_1x_dota_ridet.py --task=train
[i 0710 09:21:08.069771 28 compiler.py:956] Jittor(1.3.9.10) src: /root/.conda/lib/python3.11/site-packages/jittor
[i 0710 09:21:08.073042 28 compiler.py:957] g++ at /usr/bin/g++(9.4.0)
[i 0710 09:21:08.073148 28 compiler.py:958] cache_path: /root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default
[i 0710 09:21:08.078040 28 __init__.py:412] Found nvcc(12.2.140) at /usr/local/cuda/bin/nvcc.
[i 0710 09:21:08.081562 28 __init__.py:412] Found addr2line(2.34) at /usr/bin/addr2line.
[i 0710 09:21:08.602552 28 compiler.py:1011] cuda key:cu12.2.140_sm_80
[i 0710 09:21:10.793884 28 __init__.py:227] Total mem: 2014.99GB, using 16 procs for compiling.
[i 0710 09:21:12.343538 28 jit_compiler.cc:28] Load cc_path: /usr/bin/g++
[i 0710 09:21:12.870103 28 init.cc:63] Found cuda archs: [80,]
[i 0710 09:21:18.345738 28 cuda_flags.cc:49] CUDA enabled.
Loading config from:  ./configs/s2anet/s2anet_r50_fpn_1x_dota_ridet.py
[w 0710 09:21:19.627118 28 __init__.py:1625] load parameter fc.weight failed ...
[w 0710 09:21:19.627230 28 __init__.py:1625] load parameter fc.bias failed ...
[w 0710 09:21:19.627324 28 __init__.py:1644] load total 267 params, 2 failed
Wed Jul 10 09:21:19 2024 Start running

Compiling Operators(102/102) used: 3.31s eta:    0s 
/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_3__in0_dim_2__in0_type_int32__in1_dim_1__in1_type_int32__in2_dim_1__in2_type___hash_4278b4e1673a4c1d_op.cc(232): error: namespace "thrust" has no member "sequence"
                  thrust::sequence(array_ptr, array_ptr + dimlen);
                          ^

/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_3__in0_dim_2__in0_type_int32__in1_dim_1__in1_type_int32__in2_dim_1__in2_type___hash_4278b4e1673a4c1d_op.cc(233): error: namespace "thrust" has no member "unique"
                  int32_t num = thrust::unique(array_ptr, array_ptr + dimlen,
                                        ^

2 errors detected in the compilation of "/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_3__in0_dim_2__in0_type_int32__in1_dim_1__in1_type_int32__in2_dim_1__in2_type___hash_4278b4e1673a4c1d_op.cc".
/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_1__in0_dim_2__in0_type_int32__OUT_SIZE_1__out0_dim_1__out0_type_int32__HEADE___hash_74e193fbf19cd47a_op.cc(162): error: namespace "thrust" has no member "sort"
                      thrust::sort(thrust::device, indice_ptr, indice_ptr + dimlen,
                              ^

1 error detected in the compilation of "/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_1__in0_dim_2__in0_type_int32__OUT_SIZE_1__out0_dim_1__out0_type_int32__HEADE___hash_74e193fbf19cd47a_op.cc".
Traceback (most recent call last):
  File "/root/JDet/tools/run_net.py", line 62, in <module>
    main()
  File "/root/JDet/tools/run_net.py", line 53, in main
    runner.run()
  File "/root/JDet/python/jdet/runner/runner.py", line 84, in run
    self.train()
  File "/root/JDet/python/jdet/runner/runner.py", line 125, in train
    losses = self.model(images,targets)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.conda/lib/python3.11/site-packages/jittor/__init__.py", line 1203, in __call__
    return self.execute(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/JDet/python/jdet/models/networks/s2anet.py", line 35, in execute
    outputs = self.bbox_head(features, targets)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.conda/lib/python3.11/site-packages/jittor/__init__.py", line 1203, in __call__
    return self.execute(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/JDet/python/jdet/models/roi_heads/s2anet_head.py", line 627, in execute
    return self.loss(*outs,*self.parse_targets(targets))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/JDet/python/jdet/models/roi_heads/s2anet_head.py", line 349, in loss
    cls_reg_targets = anchor_target(
                      ^^^^^^^^^^^^^^
  File "/root/JDet/python/jdet/models/boxes/anchor_target.py", line 61, in anchor_target
    pos_inds_list, neg_inds_list) = multi_apply(
                                    ^^^^^^^^^^^^
  File "/root/JDet/python/jdet/utils/general.py", line 53, in multi_apply
    return tuple(map(list, zip(*map_results)))
                           ^^^^^^^^^^^^^^^^^
  File "/root/JDet/python/jdet/models/boxes/anchor_target.py", line 140, in anchor_target_single
    sampling_result = bbox_sampler.sample(assign_result, anchors,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/JDet/python/jdet/models/boxes/sampler.py", line 127, in sample
    pos_inds = jt.nonzero(assign_result.gt_inds > 0).squeeze(-1).unique()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.conda/lib/python3.11/site-packages/jittor/misc.py", line 677, in unique
    output, inverse = jt.code(
                      ^^^^^^^^
RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.ops.code)).

Types of your inputs are:
 self   = module,
 args   = (list, list, list, ),
 kwargs = {cpu_header=str, cpu_src=str, cuda_header=str, cuda_src=str, },

The function declarations are:
 VarHolder* code(NanoVector shape,  NanoString dtype, vector<VarHolder*>&& inputs={},  string&& cpu_src="",  vector<string>&& cpu_grad_src={},  string&& cpu_header="",  string&& cuda_src="",  vector<string>&& cuda_grad_src={},  string&& cuda_header="",  DataMap&& data={})
 vector_to_tuple<VarHolder*> code_(vector<NanoVector>&& shapes,  vector<NanoString>&& dtypes, vector<VarHolder*>&& inputs={},  string&& cpu_src="",  vector<string>&& cpu_grad_src={},  string&& cpu_header="",  string&& cuda_src="",  vector<string>&& cuda_grad_src={},  string&& cuda_header="",  DataMap&& data={})
 vector_to_tuple<VarHolder*> code__(vector<VarHolder*>&& inputs, vector<VarHolder*>&& outputs,  string&& cpu_src="",  vector<string>&& cpu_grad_src={},  string&& cpu_header="",  string&& cuda_src="",  vector<string>&& cuda_grad_src={},  string&& cuda_header="",  DataMap&& data={})

Failed reason:[f 0710 09:21:33.695908 28 parallel_compiler.cc:331] Error happend during compilation:
 [Error] source file location:/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_1__in0_dim_2__in0_type_int32__OUT_SIZE_1__out0_dim_1__out0_type_int32__HEADE___hash_74e193fbf19cd47a_op.cc
Compile operator(0/10)failed:Op(20056:0:1:1:i1:o1:s0:g1,code->20057)

Reason: [f 0710 09:21:32.976018 68:C0 log.cc:605] Check failed: ret>=0 && ret<=256  Run cmd failed: "/usr/local/cuda/bin/nvcc" "/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_1__in0_dim_2__in0_type_int32__OUT_SIZE_1__out0_dim_1__out0_type_int32__HEADE___hash_74e193fbf19cd47a_op.cc"      -std=c++14 -Xcompiler -fPIC  -Xcompiler -march=native  -Xcompiler -fdiagnostics-color=always  -lstdc++ -ldl -shared  -I"/root/.conda/lib/python3.11/site-packages/jittor/src" -I/root/.conda/include/python3.11 -I/root/.conda/include/python3.11 -DHAS_CUDA -DIS_CUDA -I"/usr/local/cuda/include" -I"/root/.conda/lib/python3.11/site-packages/jittor/extern/cuda/inc"  -lcudart -L"/usr/local/cuda/lib64" -Xlinker -rpath="/usr/local/cuda/lib64"  -I"/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80" -L"/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80" -Xlinker -rpath="/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80" -L"/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default" -Xlinker -rpath="/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default"  -l:"jit_utils_core.cpython-311-x86_64-linux-gnu".so  -l:"jittor_core.cpython-311-x86_64-linux-gnu".so  -x cu --cudart=shared -ccbin="/usr/bin/g++" --use_fast_math  -w  -I"/root/.conda/lib/python3.11/site-packages/jittor/extern/cuda/inc"  -arch=compute_80  -code=sm_80    --extended-lambda    -o "/root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default/cu12.2.140_sm_80/jit/code__IN_SIZE_1__in0_dim_2__in0_type_int32__OUT_SIZE_1__out0_dim_1__out0_type_int32__HEADE___hash_74e193fbf19cd47a_op.so" 
return 512. This might be an overcommit issue or out of memory. Try : sudo sysctl vm.overcommit_memory=1, or set enviroment variable `export DISABLE_MULTIPROCESSING=1`

(base) root@intern-studio-50088800:~/JDet# nvidia-smi
Wed Jul 10 09:22:39 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   34C    P0              63W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

(base) root@intern-studio-50088800:~/JDet# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0

(base) root@intern-studio-50088800:~/JDet# g++ --version
g++ (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

(base) root@intern-studio-50088800:~# python -m jittor.test.test_cuda 
[i 0710 09:32:31.210559 88 compiler.py:956] Jittor(1.3.9.10) src: /root/.conda/lib/python3.11/site-packages/jittor
[i 0710 09:32:31.214289 88 compiler.py:957] g++ at /usr/bin/g++(9.4.0)
[i 0710 09:32:31.214391 88 compiler.py:958] cache_path: /root/.cache/jittor/jt1.3.9/g++9.4.0/py3.11.5/Linux-5.10.134xce/IntelRXeonRPlax86/c9b7/default
[i 0710 09:32:31.219581 88 __init__.py:412] Found nvcc(12.2.140) at /usr/local/cuda/bin/nvcc.
[i 0710 09:32:31.223893 88 __init__.py:412] Found addr2line(2.34) at /usr/bin/addr2line.
[i 0710 09:32:31.636232 88 compiler.py:1011] cuda key:cu12.2.140_sm_80
[i 0710 09:32:33.707237 88 __init__.py:227] Total mem: 2014.99GB, using 16 procs for compiling.
[i 0710 09:32:35.692703 88 jit_compiler.cc:28] Load cc_path: /usr/bin/g++
[i 0710 09:32:36.005842 88 init.cc:63] Found cuda archs: [80,]
[i 0710 09:32:39.792016 88 cuda_flags.cc:49] CUDA enabled.

Compiling Operators(1/1) used: 2.27s eta:    0s 
.[i 0710 09:32:44.402741 88 cuda_flags.cc:49] CUDA enabled.
.[i 0710 09:32:44.539658 88 cuda_flags.cc:49] CUDA enabled.

Compiling Operators(1/1) used: 3.57s eta:    0s 
.[i 0710 09:32:48.106413 88 cuda_flags.cc:49] CUDA enabled.
.s
----------------------------------------------------------------------
Ran 5 tests in 10.093s

OK (skipped=1)

Minimal Reproduce

Expected behavior

I'd like the model to train normally.

krauwu commented 1 month ago

Jittor's unique method seems to have broken as the CUDA Thrust library version changed. You need to add the following to the jt.code header:

#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
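jt.code takes the header as a plain string (the cuda_header parameter in the function declarations printed in the log), so the fix above amounts to making sure that string contains the three includes. The helper below is a hypothetical sketch of that string transformation only; where you apply it (for example, the cuda_header passed by the jt.code call in jittor/misc.py's unique, line 677 in the traceback) is an assumption you should check against your installed Jittor.

```python
# The three Thrust headers named in the comment above
REQUIRED_THRUST_HEADERS = (
    "#include <thrust/sequence.h>",
    "#include <thrust/sort.h>",
    "#include <thrust/unique.h>",
)


def add_thrust_includes(cuda_header: str) -> str:
    """Hypothetical helper: prepend any missing Thrust includes to a
    cuda_header string, leaving headers that already have them unchanged."""
    # Compare stripped lines so indentation in the header does not matter
    present = {line.strip() for line in cuda_header.splitlines()}
    missing = [inc for inc in REQUIRED_THRUST_HEADERS if inc not in present]
    if not missing:
        return cuda_header
    # Put the missing includes first so they precede any code that uses them
    return "\n".join(missing) + "\n" + cuda_header
```

Because the function returns its input unchanged once all three includes are present, running it more than once is harmless.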

514flowey commented 1 month ago

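JDet was developed on Jittor 1.3.6.3, and since later Jittor updates, some operators may no longer be used correctly. For now, we recommend using JDet with Jittor 1.3.6.3 and Python 3.7. We will update JDet later to keep up with new Jittor versions.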
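Before launching training, a quick check that the environment matches this recommendation can save a long compile-and-fail cycle (the log above shows Jittor 1.3.9.10 on Python 3.11). The function names below are hypothetical. Versions are compared as integer tuples because plain string comparison would misorder releases such as 1.3.10 and 1.3.9.

```python
import sys
from typing import List, Tuple

# Versions recommended in the maintainer's comment above
RECOMMENDED_JITTOR = "1.3.6.3"
RECOMMENDED_PYTHON = (3, 7)


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn a dotted version like '1.3.9.10' into (1, 3, 9, 10)."""
    return tuple(int(part) for part in text.strip().split("."))


def check_jdet_env(
    jittor_version: str,
    python_version: Tuple[int, ...] = tuple(sys.version_info[:2]),
) -> List[str]:
    """Hypothetical helper: return warnings for versions that differ from
    the ones JDet is recommended with; an empty list means they match."""
    warnings: List[str] = []
    if parse_version(jittor_version) != parse_version(RECOMMENDED_JITTOR):
        warnings.append(
            f"Jittor {jittor_version} found; JDet was developed on {RECOMMENDED_JITTOR}"
        )
    if tuple(python_version[:2]) != RECOMMENDED_PYTHON:
        warnings.append(
            f"Python {python_version[0]}.{python_version[1]} found; "
            f"Python {RECOMMENDED_PYTHON[0]}.{RECOMMENDED_PYTHON[1]} is recommended"
        )
    return warnings
```

Assuming the installed package exposes jittor.__version__, usage would be: import jittor; print(check_jdet_env(jittor.__version__)).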