naver / splade

SPLADE: sparse neural search (SIGIR21, SIGIR22)
Other

[Bug] Get PyTorch version #42

Closed · leobavila closed this issue 11 months ago

leobavila commented 11 months ago

Hi, I believe there is a bug in the check for whether the PyTorch version is >= 1.6.

https://github.com/naver/splade/blob/main/splade/tasks/amp.py

import torch

# inspired from Colbert repo: https://github.com/stanford-futuredata/ColBERT

PyTorch_over_1_6 = float((torch.__version__.split('.')[1])) >= 6 and float((torch.__version__.split('.')[0])) >= 1

It returns False for PyTorch version '2.0.1+cu117' (Google Colab). The minor version there is 0, which is not >= 6, even though 2.0 is newer than 1.6. Could you check it, please? I have replaced the check with this one:

PyTorch_over_1_6 = float(".".join([torch.__version__.split('.')[0], torch.__version__.split('.')[1]])) >= 1.6
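Both checks can be compared on plain version strings without installing PyTorch. The sketch below is illustrative only: the helper names `original_check`, `float_join_check` and `tuple_check` are hypothetical. The first reproduces the original expression, the second the float-based replacement above, and the third is an assumed alternative that compares (major, minor) as an integer tuple. One caveat with the float approach is that two-digit minor versions compare incorrectly, because float("1.10") == 1.1, which is less than 1.6.

```python
import re


def original_check(version: str) -> bool:
    # Original expression: minor >= 6 AND major >= 1, so 2.0, 2.1, ... fail.
    parts = version.split('.')
    return float(parts[1]) >= 6 and float(parts[0]) >= 1


def float_join_check(version: str) -> bool:
    # Replacement from this issue: read "major.minor" as a float.
    # Fixes 2.0.1, but float("1.13") == 1.13 < 1.6, so 1.10-1.13 fail.
    parts = version.split('.')
    return float(".".join([parts[0], parts[1]])) >= 1.6


def tuple_check(version: str, minimum=(1, 6)) -> bool:
    # Hypothetical alternative: compare (major, minor) as integers.
    # re.match keeps only the leading digits of each part, in case a local
    # suffix such as "+cu117" directly follows the minor number.
    major, minor = (int(re.match(r"\d+", p).group()) for p in version.split('.')[:2])
    return (major, minor) >= minimum


for v in ["1.5.0", "1.6.0", "1.13.1", "2.0.1+cu117"]:
    print(v, original_check(v), float_join_check(v), tuple_check(v))
```

In this sketch, "1.13.1" gives True / False / True and "2.0.1+cu117" gives False / True / True, so only the tuple comparison accepts both. A library option with the same effect is comparing `packaging.version.parse(torch.__version__)` against `packaging.version.parse("1.6")`, which understands local suffixes like "+cu117".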

Full example: [screenshot not preserved]

This bug makes the code break when this PyTorch version is used with fp16 = True.

Error message: "Cannot use AMP for PyTorch version < 1.6"
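Presumably the version flag gates mixed precision, so the wrong value only matters once fp16 is requested. A hypothetical sketch of such a guard follows; the function `resolve_amp` and the choice of `RuntimeError` are assumptions for illustration, not SPLADE's actual code, and only the error message is taken from this issue.

```python
def resolve_amp(fp16: bool, pytorch_over_1_6: bool) -> bool:
    # Hypothetical guard: requesting fp16 while the version flag is False
    # raises the error quoted above.
    if fp16 and not pytorch_over_1_6:
        raise RuntimeError("Cannot use AMP for PyTorch version < 1.6")
    # Otherwise AMP is used exactly when fp16 was requested.
    return fp16


# With the original check, PyTorch 2.0.1 yields False, so fp16=True fails:
try:
    resolve_amp(fp16=True, pytorch_over_1_6=False)
except RuntimeError as err:
    print(err)  # Cannot use AMP for PyTorch version < 1.6

print(resolve_amp(fp16=True, pytorch_over_1_6=True))    # True
print(resolve_amp(fp16=False, pytorch_over_1_6=False))  # False
```

In this sketch the flag is never consulted when fp16 = False, which would explain why the bug only surfaces with fp16 = True.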


leobavila commented 11 months ago

I have sent a PR to fix it.

cadurosar commented 11 months ago

Yes, this was indeed a bug. We had "fixed" it in the upcoming version by removing the check, but your fix is actually better than what we had.

Thanks a lot!

(I'm closing this since I merged the PR, but feel free to reopen it if I forgot something.)