BatsResearch / csp

Learning to compose soft prompts for compositional zero-shot learning.
BSD 3-Clause "New" or "Revised" License

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) #19

Closed ans92 closed 8 months ago

ans92 commented 8 months ago

Hi, I am running into the following error when I start training:

training details
Namespace(experiment_name='csp', dataset='mit-states', lr=5e-05, weight_decay=1e-05, clip_model='ViT-L/14', epochs=20, train_batch_size=4, eval_batch_size=4, evaluate_only=False, context_length=8, attr_dropout=0.3, save_path='/home/ans/CZSL/csp-model-saved/mit-states/sample_model', save_every_n=1, save_model=False, seed=0, gradient_accumulation_steps=2)
####
/home/ans/DATA_ROOT/mit-states/
# train pairs: 1262 | # val pairs: 600 | # test pairs: 800
# train images: 30338 | # val images: 10420 | # test images: 12995
model dtype torch.float16
soft embedding dtype torch.float32
epoch   1:   0%|                                                        | 0/7585 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/ans/CZSL/csp-main/train.py", line 218, in <module>
    model, optimizer = train_model(
  File "/home/ans/CZSL/csp-main/train.py", line 66, in train_model
    logits = model(batch_feat, train_pairs)
  File "/home/ans/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ans/CZSL/csp-main/clip_modules/interface.py", line 103, in forward
    token_tensors = self.construct_token_tensors(idx)
  File "/home/ans/CZSL/csp-main/models/csp.py", line 55, in construct_token_tensors
    token_tensor[:, eos_idx - 2, :] = soft_embeddings[
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
epoch   1:   0%|                                                        | 0/7585 [00:00<?, ?it/s]

I think the problem is that attr_idx on line 56 of models/csp.py is on CUDA, while the tensor it indexes (soft_embeddings) is on the CPU, as the error message indicates. Could you please help me with this? Thank you.
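For readers hitting the same message, here is a minimal sketch of the device mismatch described above and one generic way around it. It is not the repository's code: the helper index_on_same_device and the tensor shapes are hypothetical, and only the names soft_embeddings and attr_idx are borrowed from the traceback. It assumes the indexed tensor lives on the CPU while the index tensor may live on the GPU.

```python
import torch


def index_on_same_device(table: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Index `table` with `idx` after moving the indices to the table's device.

    Hypothetical helper for illustration; not part of the csp repository.
    """
    return table[idx.to(table.device)]


# Stand-ins for the tensors named in the traceback (shapes are made up).
soft_embeddings = torch.randn(10, 4)  # on the CPU, like the indexed tensor in the error

# Put the indices on the GPU when one is available, to mirror the report.
device = "cuda" if torch.cuda.is_available() else "cpu"
attr_idx = torch.tensor([1, 3, 5], device=device)

rows = index_on_same_device(soft_embeddings, attr_idx)
print(rows.shape, rows.device)
```

Moving the indices to the CPU only removes the error. For actual training it is probably preferable to move the indexed tensor to the GPU instead (for example with soft_embeddings.to(attr_idx.device)), so that the lookup does not bounce between devices on every step.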

WangZN11 commented 8 months ago

I ran into the same issue, which may be caused by the PyTorch version. After installing with pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 in a conda environment with Python 3.8, the error no longer occurred. Alternatively, you can follow the author's setup instructions and install all dependencies with pip install -r requirements.txt.
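To check whether an environment matches the versions mentioned above, a small standard-library sketch like the following may help. The PINNED table simply copies the versions from the pip command in this comment; the helpers base_version and check_pins are hypothetical names, and the comparison ignores local build tags such as +cu117.

```python
from importlib import metadata

# Versions copied from the pip command quoted in this thread.
PINNED = {"torch": "1.13.1", "torchvision": "0.14.1", "torchaudio": "0.13.1"}


def base_version(version: str) -> str:
    """Drop a local build tag, e.g. '1.13.1+cu117' -> '1.13.1'."""
    return version.split("+", 1)[0]


def check_pins(pins, get_version=metadata.version):
    """Return a {package: status} report comparing installed versions to pins."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
            continue
        if base_version(installed) == wanted:
            report[name] = "ok"
        else:
            report[name] = f"mismatch (installed {installed})"
    return report


print(check_pins(PINNED))
```

A mismatch reported here does not prove that the version is the cause of the error; it only shows where the environment differs from the one reported to work.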

nihalnayak commented 8 months ago

Let me know if this issue is still not resolved. Closing this issue for now.