yanzq95 / DHD

Deep Height Decoupling for Precise Vision-based 3D Occupancy Prediction
Apache License 2.0

RuntimeError: DeformConv is not compiled with GPU support #5

Open zhaosonghui opened 3 days ago

zhaosonghui commented 3 days ago

This is exciting work! When I follow the instructions to reproduce the results, I hit the following error: "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/ops/deform_conv.py", line 92, in forward ext_module.deform_conv_forward, RuntimeError: DeformConv is not compiled with GPU support. However, mmcv is already installed in my environment. Did you encounter this error while training the model?


2024-11-28 16:01:08,419 - mmdet - INFO - workflow: [('train', 1)], max: 24 epochs
2024-11-28 16:01:08,420 - mmdet - INFO - Checkpoints will be saved to /home/zh/DHD/work_dirs/DHD-S by HardDiskBackend.
/home/zh/DHD/projects/mmdet3d_plugin/datasets/pipelines/loading.py:361: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:201.)
  gt_boxes, gt_labels = torch.Tensor(gt_boxes), torch.tensor(gt_labels)
Traceback (most recent call last):
  File "tools/train.py", line 287, in <module>
    main()
  File "tools/train.py", line 276, in main
    train_model(
  File "/home/zh/DHD/mmdetection3d/mmdet3d/apis/train.py", line 344, in train_model
    train_detector(
  File "/home/zh/DHD/mmdetection3d/mmdet3d/apis/train.py", line 319, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 130, in run
    epoch_runner(data_loaders[i], kwargs)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 51, in train
    self.run_iter(data_batch, train_mode=True, kwargs)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/parallel/distributed.py", line 59, in train_step
    output = self.module.train_step(inputs[0], kwargs[0])
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 248, in train_step
    losses = self(data)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(input, kwargs)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 116, in new_func
    return old_func(*args, kwargs)
  File "/home/zh/DHD/mmdetection3d/mmdet3d/models/detectors/base.py", line 60, in forward
    return self.forward_train(kwargs)
  File "/home/zh/DHD/projects/mmdet3d_plugin/models/detectors/DHD_model.py", line 175, in forward_train
    x_2d, x_3d, pts_feats, depth, height = self.extract_feat (points, img_inputs=img_inputs, img_metas=img_metas,
  File "/home/zh/DHD/projects/mmdet3d_plugin/models/detectors/DHD_model.py", line 131, in extract_feat
    x_2d, x_3d, depth, height = self.extract_img_feat (img_inputs, img_metas, *kwargs)
  File "/home/zh/DHD/projects/mmdet3d_plugin/models/detectors/DHD_model.py", line 103, in extract_img_feat
    x_2d, depth, height, mask_1, mask_2, mask_3 = self.img_view_transformer ([x, sensor2keyegos, ego2globals, intrins, post_rots,
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(input, kwargs)
  File "/home/zh/DHD/projects/mmdet3d_plugin/models/necks/lss_heightmap.py", line 487, in forward
    x_h = self.height_net(x, mlp_input, stereo_metas)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, kwargs)
  File "/home/zh/DHD/projects/mmdet3d_plugin/models/model_utils/depthnet.py", line 651, in forward
    depth = self.depth_conv(depth) # x: (BN_views, C_mid, fH, fW) --> (BN_views, D, fH, fW)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, *kwargs)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(input, kwargs)
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/ops/deform_conv.py", line 379, in forward
    return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
  File "/home/zh/anaconda3/envs/DHD1/lib/python3.8/site-packages/mmcv/ops/deform_conv.py", line 92, in forward
    ext_module.deform_conv_forward(
RuntimeError: DeformConv is not compiled with GPU support
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 54415) of binary: /home/zh/anaconda3/envs/DHD1/bin/python

Rayn-Wu commented 3 days ago

Hi, we haven't encountered such an error before.

Please check whether the version of MMCV you installed is compatible with your CUDA version. We recommend using the environment as listed in the README:

pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.5.3
pip install mmdet==2.25.1
pip install mmsegmentation==0.25.0
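
One way to run the compatibility check suggested above is to print the CUDA version PyTorch was built against next to the one MMCV's compiled ops report. The sketch below is only an illustration: the function name `describe_cuda_build` is made up for this example, and it assumes an mmcv-full 1.x install, where `mmcv.ops` exposes `get_compiling_cuda_version()` and `get_compiler_version()`.

```python
import importlib


def describe_cuda_build():
    """Collect the CUDA-related facts that decide whether mmcv ops can run on a GPU.

    `describe_cuda_build` is a hypothetical helper name, not part of DHD or MMCV.
    """
    report = {}
    try:
        torch = importlib.import_module("torch")
        report["torch"] = torch.__version__
        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
        report["cuda_available"] = torch.cuda.is_available()
    except ImportError:
        report["torch"] = "not installed"
    try:
        ops = importlib.import_module("mmcv.ops")
        # Reports the CUDA toolkit the extension was compiled with,
        # or a "not available" string for a CPU-only build.
        report["mmcv_cuda"] = ops.get_compiling_cuda_version()
        report["mmcv_compiler"] = ops.get_compiler_version()
    except ImportError as exc:
        report["mmcv_ops"] = f"unavailable: {exc}"
    return report


if __name__ == "__main__":
    for key, value in describe_cuda_build().items():
        print(f"{key}: {value}")
```

If `torch_cuda` shows a CUDA version but `mmcv_cuda` does not, the mmcv-full extension was probably built without CUDA, which would match the "not compiled with GPU support" message. In that case, reinstalling mmcv-full from a prebuilt wheel that matches the torch and CUDA pair, as described in the MMCV installation docs, is a reasonable next step.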

zhaosonghui commented 3 days ago

Thank you very much for your prompt reply! I configured the runtime environment according to the installation instructions in the repository. Could you please share the list of packages installed in your environment so that I can compare? My runtime environment is as follows:

Name                      Version         Build           Channel
_libgcc_mutex             0.1             main
_openmp_mutex             5.1             1_gnu
absl-py                   2.1.0           pypi_0          pypi
addict                    2.4.0           pypi_0          pypi
black                     24.8.0          pypi_0          pypi
ca-certificates           2024.9.24       h06a4308_0
cachetools                5.5.0           pypi_0          pypi
certifi                   2024.8.30       pypi_0          pypi
charset-normalizer        3.4.0           pypi_0          pypi
click                     8.1.7           pypi_0          pypi
contourpy                 1.1.1           pypi_0          pypi
cycler                    0.12.1          pypi_0          pypi
decorator                 5.1.1           pypi_0          pypi
descartes                 1.1.0           pypi_0          pypi
dhd-plugin                0.0.0           dev_0
exceptiongroup            1.2.2           pypi_0          pypi
fire                      0.7.0           pypi_0          pypi
flake8                    7.1.1           pypi_0          pypi
fonttools                 4.55.0          pypi_0          pypi
google-auth               2.36.0          pypi_0          pypi
google-auth-oauthlib      1.0.0           pypi_0          pypi
grpcio                    1.68.0          pypi_0          pypi
idna                      3.10            pypi_0          pypi
imageio                   2.35.1          pypi_0          pypi
importlib-metadata        8.5.0           pypi_0          pypi
importlib-resources       6.4.5           pypi_0          pypi
iniconfig                 2.0.0           pypi_0          pypi
joblib                    1.4.2           pypi_0          pypi
kiwisolver                1.4.7           pypi_0          pypi
lazy-loader               0.4             pypi_0          pypi
ld_impl_linux-64          2.40            h12ee557_0
libffi                    3.3             he6710b0_2
libgcc-ng                 11.2.0          h1234567_1
libgomp                   11.2.0          h1234567_1
libstdcxx-ng              11.2.0          h1234567_1
llvmlite                  0.36.0          pypi_0          pypi
lyft-dataset-sdk          0.0.8           pypi_0          pypi
markdown                  3.7             pypi_0          pypi
markupsafe                2.1.5           pypi_0          pypi
matplotlib                3.5.3           pypi_0          pypi
mccabe                    0.7.0           pypi_0          pypi
mmcls                     0.25.0          pypi_0          pypi
mmcv-full                 1.5.3           pypi_0          pypi
mmdet                     2.25.1          pypi_0          pypi
mmdet3d                   1.0.0rc4        dev_0
mmsegmentation            0.25.0          pypi_0          pypi
mypy-extensions           1.0.0           pypi_0          pypi
ncurses                   6.4             h6a678d5_0
networkx                  2.2             pypi_0          pypi
ninja                     1.11.1.2        pypi_0          pypi
numba                     0.53.0          pypi_0          pypi
numpy                     1.23.4          pypi_0          pypi
nuscenes-devkit           1.1.11          pypi_0          pypi
oauthlib                  3.2.2           pypi_0          pypi
opencv-python             4.10.0.84       pypi_0          pypi
openssl                   1.1.1w          h7f8727e_0
packaging                 24.2            pypi_0          pypi
pandas                    2.0.3           pypi_0          pypi
pathspec                  0.12.1          pypi_0          pypi
pillow                    10.4.0          pypi_0          pypi
pip                       24.2            py38h06a4308_0
platformdirs              4.3.6           pypi_0          pypi
plotly                    5.24.1          pypi_0          pypi
pluggy                    1.5.0           pypi_0          pypi
plyfile                   1.0.3           pypi_0          pypi
prettytable               3.11.0          pypi_0          pypi
protobuf                  5.29.0          pypi_0          pypi
pyasn1                    0.6.1           pypi_0          pypi
pyasn1-modules            0.4.1           pypi_0          pypi
pycocotools               2.0.7           pypi_0          pypi
pycodestyle               2.12.1          pypi_0          pypi
pyflakes                  3.2.0           pypi_0          pypi
pyparsing                 3.1.4           pypi_0          pypi
pyquaternion              0.9.9           pypi_0          pypi
pytest                    8.3.3           pypi_0          pypi
python                    3.8.13          haa1d7c7_1
python-dateutil           2.9.0.post0     pypi_0          pypi
pytz                      2024.2          pypi_0          pypi
pywavelets                1.4.1           pypi_0          pypi
pyyaml                    6.0.2           pypi_0          pypi
readline                  8.2             h5eee18b_0
requests                  2.32.3          pypi_0          pypi
requests-oauthlib         2.0.0           pypi_0          pypi
rsa                       4.9             pypi_0          pypi
scikit-image              0.19.3          pypi_0          pypi
scikit-learn              1.3.2           pypi_0          pypi
scipy                     1.10.1          pypi_0          pypi
setuptools                59.5.0          pypi_0          pypi
shapely                   1.8.5.post1     pypi_0          pypi
six                       1.16.0          pypi_0          pypi
sqlite                    3.45.3          h5eee18b_0
tenacity                  9.0.0           pypi_0          pypi
tensorboard               2.14.0          pypi_0          pypi
tensorboard-data-server   0.7.2           pypi_0          pypi
termcolor                 2.4.0           pypi_0          pypi
terminaltables            3.1.10          pypi_0          pypi
threadpoolctl             3.5.0           pypi_0          pypi
tifffile                  2023.7.10       pypi_0          pypi
tk                        8.6.14          h39e8969_0
tomli                     2.2.1           pypi_0          pypi
torch                     1.10.0+cu113    pypi_0          pypi
torchaudio                0.10.0+rocm4.1  pypi_0          pypi
torchvision               0.11.0+cu113    pypi_0          pypi
tqdm                      4.67.1          pypi_0          pypi
trimesh                   2.35.39         pypi_0          pypi
typing-extensions         4.12.2          pypi_0          pypi
tzdata                    2024.2          pypi_0          pypi
urllib3                   2.2.3           pypi_0          pypi
wcwidth                   0.2.13          pypi_0          pypi
werkzeug                  3.0.6           pypi_0          pypi
wheel                     0.44.0          py38h06a4308_0
xz                        5.4.6           h5eee18b_1
yapf                      0.40.1          pypi_0          pypi
zipp                      3.20.2          pypi_0          pypi
zlib                      1.2.13          h5eee18b_1

Rayn-Wu commented 3 days ago

Packages in our env are:

_libgcc_mutex             0.1                        main  
absl-py                   2.0.0                     <pip>
addict                    2.4.0                     <pip>
appdirs                   1.4.4                     <pip>
black                     23.11.0                   <pip>
ca-certificates           2023.08.22           h06a4308_0  
cachetools                5.3.2                     <pip>
certifi                   2023.11.17                <pip>
charset-normalizer        3.3.2                     <pip>
click                     8.1.7                     <pip>
contourpy                 1.1.1                     <pip>
cycler                    0.12.1                    <pip>
decorator                 5.1.1                     <pip>
descartes                 1.1.0                     <pip>
exceptiongroup            1.2.0                     <pip>
ez_setup                  0.9                       <pip>
fire                      0.5.0                     <pip>
flake8                    6.1.0                     <pip>
flatbuffers               23.5.26                   <pip>
fonttools                 4.45.1                    <pip>
google-auth               2.23.4                    <pip>
google-auth-oauthlib      1.0.0                     <pip>
grpcio                    1.59.3                    <pip>
idna                      3.6                       <pip>
imageio                   2.33.0                    <pip>
importlib-metadata        6.8.0                     <pip>
importlib-resources       6.1.1                     <pip>
iniconfig                 2.0.0                     <pip>
joblib                    1.3.2                     <pip>
kiwisolver                1.4.5                     <pip>
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.3                  he6710b0_2  
libgcc-ng                 9.1.0                hdf63c60_0  
libstdcxx-ng              9.1.0                hdf63c60_0  
llvmlite                  0.36.0                    <pip>
lyft-dataset-sdk          0.0.8                     <pip>
Mako                      1.3.0                     <pip>
Markdown                  3.5.1                     <pip>
MarkupSafe                2.1.3                     <pip>
matplotlib                3.5.3                     <pip>
mccabe                    0.7.0                     <pip>
mmcls                     0.25.0                    <pip>
mmcv-full                 1.5.3                     <pip>
mmdet                     2.25.1                    <pip>
mmsegmentation            0.25.0                    <pip>
mypy-extensions           1.0.0                     <pip>
ncurses                   6.3                  h7f8727e_2  
networkx                  2.2                       <pip>
ninja                     1.11.1.1                  <pip>
numba                     0.53.0                    <pip>
numpy                     1.23.4                    <pip>
nuscenes-devkit           1.1.11                    <pip>
oauthlib                  3.2.2                     <pip>
onnxruntime-gpu           1.8.1                     <pip>
opencv-python             4.8.1.78                  <pip>
openssl                   1.1.1w               h7f8727e_0  
packaging                 23.2                      <pip>
pandas                    2.0.3                     <pip>
pathspec                  0.11.2                    <pip>
Pillow                    10.1.0                    <pip>
pip                       23.3.1           py38h06a4308_0  
pip                       23.3.2                    <pip>
platformdirs              4.0.0                     <pip>
plotly                    5.18.0                    <pip>
pluggy                    1.3.0                     <pip>
plyfile                   1.0.2                     <pip>
prettytable               3.9.0                     <pip>
protobuf                  4.25.1                    <pip>
pyasn1                    0.5.1                     <pip>
pyasn1-modules            0.3.0                     <pip>
pycocotools               2.0.7                     <pip>
pycodestyle               2.11.1                    <pip>
pycuda                    2023.1                    <pip>
pyflakes                  3.1.0                     <pip>
pyparsing                 3.1.1                     <pip>
pyquaternion              0.9.9                     <pip>
pytest                    7.4.3                     <pip>
python                    3.8.13               h12debd9_0  
python-dateutil           2.8.2                     <pip>
pytools                   2023.1.1                  <pip>
pytz                      2023.3.post1              <pip>
PyWavelets                1.4.1                     <pip>
PyYAML                    6.0.1                     <pip>
readline                  8.1.2                h7f8727e_1  
requests                  2.31.0                    <pip>
requests-oauthlib         1.3.1                     <pip>
rsa                       4.9                       <pip>
scikit-image              0.19.3                    <pip>
scikit-learn              1.3.2                     <pip>
scipy                     1.10.1                    <pip>
setuptools                68.0.0           py38h06a4308_0  
setuptools                59.5.0                    <pip>
Shapely                   1.8.5.post1               <pip>
six                       1.16.0                    <pip>
sqlite                    3.38.5               hc218d9a_0  
tenacity                  8.2.3                     <pip>
tensorboard               2.14.0                    <pip>
tensorboard-data-server   0.7.2                     <pip>
termcolor                 2.3.0                     <pip>
terminaltables            3.1.10                    <pip>
threadpoolctl             3.2.0                     <pip>
tifffile                  2023.7.10                 <pip>
tk                        8.6.12               h1ccaba5_0  
tomli                     2.0.1                     <pip>
torch                     1.10.0+cu113              <pip>
torchaudio                0.10.0+rocm4.1            <pip>
torchvision               0.11.0+cu113              <pip>
tqdm                      4.66.1                    <pip>
trimesh                   2.35.39                   <pip>
typing_extensions         4.8.0                     <pip>
tzdata                    2023.3                    <pip>
urllib3                   2.1.0                     <pip>
wcwidth                   0.2.12                    <pip>
Werkzeug                  3.0.1                     <pip>
wheel                     0.41.2           py38h06a4308_0  
xz                        5.2.5                h7f8727e_1  
yapf                      0.40.0                    <pip>
zipp                      3.17.0                    <pip>
zlib                      1.2.12               h7f8727e_2
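
Comparing two package lists like these by eye is error-prone, so a short script can help. The following is a minimal sketch, not something from the DHD repository: `parse_conda_list` and `diff_envs` are hypothetical helper names, and the parser assumes each row starts with the package name followed by the version, as in both listings in this thread.

```python
def parse_conda_list(text):
    """Map package name -> version from `conda list`-style output.

    Names are lowercased and underscores become hyphens so that, for
    example, `Pillow` and `pillow` compare as the same package. If a
    package is listed twice, the later row wins.
    """
    packages = {}
    for line in text.splitlines():
        fields = line.split()
        if len(fields) < 2 or line.lstrip().startswith("#"):
            continue  # skip blank lines and comment headers
        packages[fields[0].lower().replace("_", "-")] = fields[1]
    return packages


def diff_envs(mine, theirs):
    """Return {name: (my_version, their_version)} for every mismatch."""
    names = sorted(set(mine) | set(theirs))
    return {
        name: (mine.get(name), theirs.get(name))
        for name in names
        if mine.get(name) != theirs.get(name)
    }


# Small excerpts from the two environments posted in this thread.
mine = parse_conda_list(
    "mmcv-full 1.5.3 pypi_0 pypi\n"
    "ninja 1.11.1.2 pypi_0 pypi\n"
    "pillow 10.4.0 pypi_0 pypi\n"
)
theirs = parse_conda_list(
    "mmcv-full 1.5.3 <pip>\n"
    "ninja 1.11.1.1 <pip>\n"
    "Pillow 10.1.0 <pip>\n"
    "pycuda 2023.1 <pip>\n"
)
for name, (a, b) in diff_envs(mine, theirs).items():
    print(f"{name}: mine={a} theirs={b}")
```

Note that a version diff cannot show whether mmcv-full was compiled with CUDA: both environments report mmcv-full 1.5.3, so that question needs a runtime check of the installed build.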

zhaosonghui commented 3 days ago

Thanks!

zhaosonghui commented 21 hours ago

Thank you very much! I've successfully run the code. Due to memory limitations, I set the batch size to 1 and trained DHD-L on two 4090 GPUs, which took about 8 days. I would like to ask if this training time is normal. What GPU did you use for training, and how long did it take?

Rayn-Wu commented 18 hours ago

  1. That is normal when training on two GPUs.
  2. We use six 4090 GPUs with two samples per GPU. Training takes around three days.
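
To put the two setups side by side, the effective batch size and the number of optimizer steps per epoch can be worked out directly. This is a rough sketch under stated assumptions: the helper names are made up, and `NUM_TRAIN_SAMPLES` is an assumed dataset size (28,130, the size of the nuScenes training split), since the thread does not state how many samples DHD trains on.

```python
import math

# Assumption: number of training samples per epoch; not stated in the thread.
NUM_TRAIN_SAMPLES = 28130


def effective_batch(num_gpus, samples_per_gpu):
    """Samples consumed per optimizer step across all GPUs."""
    return num_gpus * samples_per_gpu


def iters_per_epoch(num_samples, num_gpus, samples_per_gpu):
    """Optimizer steps needed to see every sample once."""
    return math.ceil(num_samples / effective_batch(num_gpus, samples_per_gpu))


reporter = (2, 1)  # two 4090s, samples_per_gpu=1
authors = (6, 2)   # six 4090s, samples_per_gpu=2

for label, (gpus, per_gpu) in (("reporter", reporter), ("authors", authors)):
    print(
        f"{label}: effective batch {effective_batch(gpus, per_gpu)}, "
        f"{iters_per_epoch(NUM_TRAIN_SAMPLES, gpus, per_gpu)} iterations per epoch"
    )
```

The two-GPU run needs about six times as many iterations per epoch, although each iteration processes fewer samples per GPU. The reported wall-clock gap (about 8 days versus about 3) is smaller than that factor, so a longer run on two GPUs is plausible rather than a sign of misconfiguration.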

zhaosonghui commented 5 hours ago

I am also using 4090 GPUs. When training DHD-L with 'samples_per_gpu=1', each GPU uses about 20GB of memory. However, setting 'samples_per_gpu=2' exceeds the available GPU memory. The original config sets 'samples_per_gpu=2, workers_per_gpu=4'. Did you use this configuration for training DHD-L?
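
For reference, the two settings discussed here normally live in the `data` dict of an MMDetection-style Python config. The fragment below is a hypothetical excerpt showing only the reduced setting; the actual DHD-L config file contains many more keys (dataset definitions, pipelines) and may be organized differently.

```python
# Hypothetical excerpt of an MMDetection-style config; only the two keys
# discussed in this thread are shown.
data = dict(
    samples_per_gpu=1,  # per-GPU batch size; the repository default is 2
    workers_per_gpu=4,  # dataloader worker processes per GPU
)
```

Lowering `samples_per_gpu` reduces the effective batch size unless the number of GPUs grows to compensate, so results may differ slightly from those obtained with the authors' setup.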

Rayn-Wu commented 2 hours ago

Yes, we use this config for DHD-L, but we didn't encounter an OOM error.