The following error occurs when running tutorial 2.1:
Traceback (most recent call last):
  File "condional_prompt.py", line 112, in <module>
    loss = prompt_model(inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 449, in forward
    return self._forward(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 467, in _forward
    logits, labels = self.shift_logits_and_labels(logits, batch['loss_ids'], reference_ids)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 434, in shift_logits_and_labels
    shift_input_ids = torch.where(shift_loss_ids>0, shift_input_ids, -100)
TypeError: where(): argument 'other' (position 3) must be Tensor, not int
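The TypeError indicates that the installed PyTorch version's `torch.where` only accepts a tensor as its third argument (`other`), while line 434 of `pipeline_base.py` passes the Python int `-100`. One possible local workaround, assuming that is the only incompatibility, is to pass a tensor of the same shape filled with `-100`. The helper `mask_non_loss_positions` below is hypothetical and not part of OpenPrompt; it only mirrors the failing line so the idea can be tested in isolation.

```python
import torch


def mask_non_loss_positions(shift_input_ids: torch.Tensor,
                            shift_loss_ids: torch.Tensor,
                            ignore_index: int = -100) -> torch.Tensor:
    # Hypothetical helper mirroring pipeline_base.py line 434.
    # Build the fill value as a tensor with the same shape, dtype and device
    # as shift_input_ids, so torch.where receives three tensors, which older
    # PyTorch releases require.
    fill = torch.full_like(shift_input_ids, ignore_index)
    # Keep token ids where loss_ids > 0, replace everything else with -100.
    return torch.where(shift_loss_ids > 0, shift_input_ids, fill)


input_ids = torch.tensor([[5, 6, 7, 8]])
loss_ids = torch.tensor([[0, 1, 1, 0]])
print(mask_non_loss_positions(input_ids, loss_ids).tolist())
# -> [[-100, 6, 7, -100]]
```

`shift_input_ids.masked_fill(shift_loss_ids <= 0, -100)` should produce the same result, since `masked_fill` accepts a Python number as the fill value. Alternatively, upgrading to a PyTorch release whose `torch.where` accepts a scalar `other` may avoid patching the installed library at all.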