ziplab / LITv2

[NeurIPS 2022 Spotlight] This is the official PyTorch implementation of "Fast Vision Transformers with HiLo Attention"
Apache License 2.0

Fine-tuning classification on a custom dataset with N classes #11

Closed: ytring closed this issue 1 year ago

ytring commented 1 year ago

Hi,

Could you please tell me how to fine-tune the classification model on a custom dataset with N classes? When I change the number of classes in data/build.py, I get the following error:

../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [23,0,0] Assertion `t >= 0 && t < n_classes` failed.
terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Exception raised from createEvent at ../aten/src/ATen/cuda/CUDAEvent.h:166 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7ff6ce16020e in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xf3cbd (0x7ff710a22cbd in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #2: <unknown function> + 0xf6ffe (0x7ff710a25ffe in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #3: <unknown function> + 0x463338 (0x7ff71fd88338 in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #4: c10::TensorImpl::release_resources() + 0x175 (0x7ff6ce1477a5 in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #5: <unknown function> + 0x35f355 (0x7ff71fc84355 in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x678d38 (0x7ff71ff9dd38 in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #7: THPVariable_subclass_dealloc(_object*) + 0x2b5 (0x7ff71ff9e0e5 in /home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #8: /home/tony/anaconda3/envs/lit/bin/python() [0x4dab0c]
frame #9: /home/tony/anaconda3/envs/lit/bin/python() [0x4f4fab]
frame #10: /home/tony/anaconda3/envs/lit/bin/python() [0x5c33ff]
frame #11: _PyEval_EvalFrameDefault + 0x5c24 (0x4ed314 in /home/tony/anaconda3/envs/lit/bin/python)
frame #12: /home/tony/anaconda3/envs/lit/bin/python() [0x4f7ec3]
frame #13: _PyEval_EvalFrameDefault + 0x301d (0x4ea70d in /home/tony/anaconda3/envs/lit/bin/python)
frame #14: /home/tony/anaconda3/envs/lit/bin/python() [0x4e67ea]
frame #15: _PyFunction_Vectorcall + 0xd5 (0x4f7be5 in /home/tony/anaconda3/envs/lit/bin/python)
frame #16: _PyEval_EvalFrameDefault + 0x3ce (0x4e7abe in /home/tony/anaconda3/envs/lit/bin/python)
frame #17: /home/tony/anaconda3/envs/lit/bin/python() [0x4f7ec3]
frame #18: _PyEval_EvalFrameDefault + 0x3ce (0x4e7abe in /home/tony/anaconda3/envs/lit/bin/python)
frame #19: /home/tony/anaconda3/envs/lit/bin/python() [0x4e67ea]
frame #20: _PyEval_EvalCodeWithName + 0x47 (0x4e6477 in /home/tony/anaconda3/envs/lit/bin/python)
frame #21: PyEval_EvalCodeEx + 0x39 (0x4e6429 in /home/tony/anaconda3/envs/lit/bin/python)
frame #22: PyEval_EvalCode + 0x1b (0x593ccb in /home/tony/anaconda3/envs/lit/bin/python)
frame #23: /home/tony/anaconda3/envs/lit/bin/python() [0x5c1077]
frame #24: /home/tony/anaconda3/envs/lit/bin/python() [0x5bd080]
frame #25: /home/tony/anaconda3/envs/lit/bin/python() [0x4564f6]
frame #26: PyRun_SimpleFileExFlags + 0x1a2 (0x5b6d62 in /home/tony/anaconda3/envs/lit/bin/python)
frame #27: Py_RunMain + 0x37e (0x5b42de in /home/tony/anaconda3/envs/lit/bin/python)
frame #28: Py_BytesMain + 0x39 (0x587d79 in /home/tony/anaconda3/envs/lit/bin/python)
frame #29: __libc_start_main + 0xf3 (0x7ff75366a083 in /lib/x86_64-linux-gnu/libc.so.6)
frame #30: /home/tony/anaconda3/envs/lit/bin/python() [0x587c2e]

ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 1259091) of binary: /home/tony/anaconda3/envs/lit/bin/python
Traceback (most recent call last):
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/distributed/run.py", line 752, in run
    elastic_launch(
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/tony/anaconda3/envs/lit/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

Thank you.

ytring commented 1 year ago

I solved the problem by making sure that the val directory contains the same number of classes as the train directory.
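
For anyone hitting the same `t >= 0 && t < n_classes` assertion: it means some target label falls outside the range the classification head expects. The sketch below is one way to check the fix described above before launching training. It assumes the dataset uses a one-sub-directory-per-class layout (for example `train/<class>/...` and `val/<class>/...`), which this thread does not state explicitly. The helper names `class_names` and `compare_splits` are hypothetical and not part of the LITv2 code base.

```python
import os
import tempfile


def class_names(split_dir):
    """Return the sorted names of the class sub-directories under split_dir."""
    return sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())


def compare_splits(train_dir, val_dir):
    """Return (classes found only in train, classes found only in val)."""
    train = set(class_names(train_dir))
    val = set(class_names(val_dir))
    return sorted(train - val), sorted(val - train)


# Demo on a throwaway directory tree; point the two paths at your own
# dataset instead. The layout here is an assumption, not taken from the repo.
with tempfile.TemporaryDirectory() as root:
    layout = {"train": ["bird", "cat", "dog"], "val": ["cat", "dog"]}
    for split, names in layout.items():
        for name in names:
            os.makedirs(os.path.join(root, split, name))

    only_train, only_val = compare_splits(
        os.path.join(root, "train"), os.path.join(root, "val")
    )
    print("classes only in train:", only_train)
    print("classes only in val:", only_val)
```

If either list is non-empty, the two splits disagree on the class set. With folder-based loaders that derive label indices from the sorted sub-directory names, such a mismatch can shift the indices or push them outside the configured number of classes.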