GorkaUrbizu closed this issue 4 years ago
Thanks for spotting this. I should update the readme to state these tutorials now use PyTorch 1.3.
I just found a nice workaround by changing
return ((y_pred > thresh) == y_true.bool()).float().mean().item()
to
return ((y_pred > thresh) == (y_true > 0.5)).float().mean().item()
and it works perfectly.
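For anyone who wants to try that workaround in isolation, here is a minimal sketch. The function name `binary_accuracy`, the default `thresh=0.5`, and the sample tensors are my own assumptions for illustration, not code from the tutorials. The idea, as I understand it, is that both sides of `==` come from a comparison, so they share whatever mask dtype the installed PyTorch uses and `.bool()` is never needed.

```python
import torch


def binary_accuracy(y_pred, y_true, thresh=0.5):
    """Fraction of predictions that match the 0/1 labels (hypothetical helper)."""
    # Both operands are produced by a comparison, so they have the same
    # dtype: uint8 on older PyTorch releases, bool on newer ones.
    pred_labels = y_pred > thresh
    true_labels = y_true > 0.5  # assumes labels are stored as 0.0 / 1.0 floats
    return (pred_labels == true_labels).float().mean().item()


# Made-up example data: three of the four predictions match the labels.
y_pred = torch.tensor([0.9, 0.2, 0.7, 0.4])
y_true = torch.tensor([1.0, 0.0, 0.0, 0.0])
acc = binary_accuracy(y_pred, y_true)
print(acc)
```

Note that `y_true > 0.5` only makes sense if the labels really are 0/1 values; with other label encodings the threshold would need to change.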
Hi!
When I tried to run the 6th transformer tutorial on my GPU, I got the following error:
This may happen if you use a PyTorch version earlier than 1.2 (1.0.1 in my case). To solve this, I changed that line, as suggested in this issue, to:
trg_sub_mask = torch.tril(torch.ones((trg_len,trg_len),device=self.device)).type(torch.uint8)
With this change it runs without problems and it seems that it works correctly.
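In case it is useful, here is a rough, version-agnostic sketch of the same mask. Instead of hard-coding `torch.uint8`, it builds the mask with a comparison, so the result has the dtype that comparisons return on the installed PyTorch. Everything except `trg_sub_mask` and `trg_len` (the helper `make_subsequent_mask`, `pad_idx`, `trg_pad_mask`, and the sample batch) is an assumption I made up for illustration and may not match the tutorial code.

```python
import torch


def make_subsequent_mask(trg_len, device=None):
    """Lower-triangular mask so position i can only attend to positions <= i."""
    # `!= 0` returns the native mask dtype of the installed PyTorch
    # (uint8 on older releases, bool on newer ones), so no explicit cast is needed.
    return torch.tril(torch.ones((trg_len, trg_len), device=device)) != 0


# Hypothetical batch with one target sequence; 0 is assumed to be the padding index.
pad_idx = 0
trg = torch.tensor([[5, 6, 0]])

# Padding mask: [batch, 1, 1, trg_len], built with a comparison as well.
trg_pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)

# Subsequent mask: [trg_len, trg_len], broadcast against the padding mask.
trg_sub_mask = make_subsequent_mask(trg.shape[1])

# Both masks share a dtype, so `&` should work on either side of the dtype change.
trg_mask = trg_pad_mask & trg_sub_mask
print(trg_mask.shape)
```

I have not run this on 1.0.1 myself, so treat it as a starting point rather than a verified fix.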
I'm opening this issue to see whether anyone sees a problem with my solution, and mainly to help anyone facing the same problem.
Gorka