I installed the latest FARM release with pip install farm into a fresh conda environment with Python 3.8. This pulled in torch 1.7.1.
I am getting the following exception:
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/farm/train.py", line 301, in train
    per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/farm/modeling/adaptive_model.py", line 381, in logits_to_loss
    all_losses = self.logits_to_loss_per_head(logits, **kwargs)
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/farm/modeling/adaptive_model.py", line 365, in logits_to_loss_per_head
    all_losses.append(head.logits_to_loss(logits=logits_for_one_head, **kwargs))
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/farm/modeling/prediction_head.py", line 358, in logits_to_loss
    return self.loss_fct(logits, label_ids.view(-1))
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/torch/nn/functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/johann/.conda/envs/femdwell-debug-1/lib/python3.8/site-packages/torch/nn/functional.py", line 2264, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: expected scalar type Float but found Double
The exception is raised when calling trainer.train().
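A minimal sketch of what the message seems to point to. It assumes, without confirming this in FARM's code, that the culprit is a float64 ("Double") class-weight tensor passed to CrossEntropyLoss while the logits are float32 ("Float"). As far as I can tell, nll_loss reads the weight tensor using the logits' dtype, so a Double weight next to Float logits fails. The logits, labels, class_weights values and the compute_loss helper are made up for illustration. NumPy arrays default to float64, and torch.tensor() keeps that dtype.

```python
import numpy as np
import torch
import torch.nn as nn

# Made-up float32 logits and int64 targets, as a classification head typically produces.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 0])

# Assumption: class weights computed with NumPy are float64 by default.
class_weights = np.array([1.0, 2.0, 0.5])


def compute_loss(weight):
    """Hypothetical helper: return the weighted loss, or the error text on a dtype clash."""
    try:
        return nn.CrossEntropyLoss(weight=weight)(logits, labels)
    except RuntimeError as err:
        return str(err)


# torch.tensor() keeps NumPy's float64, so this weight tensor is "Double".
double_weight = torch.tensor(class_weights)
# Casting the weights to the logits' dtype (float32) is the suspected workaround.
float_weight = torch.tensor(class_weights).to(logits.dtype)

print(double_weight.dtype, compute_loss(double_weight))  # dtype-mismatch error text
print(float_weight.dtype, compute_loss(float_weight))    # a scalar loss tensor
```

If this is the cause, making sure any class weights handed to the prediction head are float32 before training should avoid the error. This is an assumption to check against where the weight tensor is actually built in the installed FARM version.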