thunlp / OpenPrompt

An Open-Source Framework for Prompt-Learning.
https://thunlp.github.io/OpenPrompt/
Apache License 2.0
4.38k stars 455 forks

[Tutorial 2.1 error] TypeError: where(): argument 'other' (position 3) must be Tensor, not int #172

Open canghongjian opened 2 years ago

canghongjian commented 2 years ago

This error occurs when running Tutorial 2.1. The full traceback is:

Traceback (most recent call last):
  File "condional_prompt.py", line 112, in
    loss = prompt_model(inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 449, in forward
    return self._forward(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 467, in _forward
    logits, labels = self.shift_logits_and_labels(logits, batch['loss_ids'], reference_ids)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 434, in shift_logits_and_labels
    shift_input_ids = torch.where(shift_loss_ids>0, shift_input_ids, -100)
TypeError: where(): argument 'other' (position 3) must be Tensor, not int
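The last frame shows that `shift_logits_and_labels` passes the Python int `-100` as the third argument of `torch.where`, and the installed PyTorch build evidently accepts only a Tensor in that position. A possible local workaround, assuming you are willing to patch that line in `openprompt/pipeline_base.py` (or upgrade to a PyTorch release whose `torch.where` accepts scalars), is to pass a tensor of the same shape, dtype and device instead of the bare int. The helper name `mask_non_loss_positions` below is hypothetical and only illustrates the idea; it is not part of OpenPrompt.

```python
import torch


def mask_non_loss_positions(shift_input_ids: torch.Tensor,
                            shift_loss_ids: torch.Tensor,
                            ignore_index: int = -100) -> torch.Tensor:
    """Replace ids at positions where loss_ids <= 0 with ignore_index.

    Hypothetical helper: equivalent to
    torch.where(shift_loss_ids > 0, shift_input_ids, ignore_index)
    but builds the fill value as a tensor, so it also works on PyTorch
    versions whose torch.where requires Tensor arguments.
    """
    # Tensor filled with ignore_index, matching shape, dtype and device.
    fill = torch.full_like(shift_input_ids, ignore_index)
    return torch.where(shift_loss_ids > 0, shift_input_ids, fill)


if __name__ == "__main__":
    ids = torch.tensor([[5, 6, 7, 8]])
    loss_ids = torch.tensor([[0, 1, 1, 0]])
    print(mask_non_loss_positions(ids, loss_ids))
```

An equivalent one-liner that avoids `torch.where` entirely is `shift_input_ids.masked_fill(shift_loss_ids <= 0, -100)`, since `masked_fill` takes a plain Python number as its fill value.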