Open jiaji-huang opened 1 year ago
Actually, as a follow-up to my question: I think the issue can be resolved by installing from source. Instead of `pip install pytorch-fast-transformers`, cloning the repo and running `python setup.py install` gets rid of the NaNs I had before.

But I'd like to hear more thoughts. Is this package tested on more recent CUDA and PyTorch versions?
My code calls `fast_transformers.causal_product`, which is the only function I use from this package. I set up this package with the latest PyTorch 1.13.0 + CUDA 11.6 and get NaN errors during training. This, however, doesn't happen with the older PyTorch 1.7.1 + CUDA 11.0.
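For anyone debugging the same thing: a quick sanity check is to compare the CUDA kernel's output against a plain reference implementation of the causal product on the same inputs, so you can tell whether the NaNs come from the kernel itself or from your inputs. The sketch below is my own NumPy rendering of the linear-attention causal product, `out[i] = sum_{j<=i} (Q[i] . K[j]) V[j]` (the shapes and helper name are mine, not the package's API):

```python
import numpy as np

def causal_product_ref(Q, K, V):
    """Reference causal product: out[i] = sum_{j<=i} (Q[i] . K[j]) V[j].

    Shapes (single head, for clarity): Q, K are (L, E); V is (L, M).
    Returns an (L, M) array. Runs in O(L * E * M) using a running sum,
    mirroring the recurrence the fast kernel is based on.
    """
    L, M = V.shape
    E = Q.shape[1]
    out = np.zeros((L, M))
    S = np.zeros((E, M))  # running sum of outer(K[j], V[j]) for j <= i
    for i in range(L):
        S += np.outer(K[i], V[i])
        out[i] = Q[i] @ S
    return out

rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 4))
K = rng.standard_normal((5, 4))
V = rng.standard_normal((5, 3))
ref = causal_product_ref(Q, K, V)
assert np.isfinite(ref).all()  # finite here: if the GPU path NaNs, suspect the kernel build
```

If the reference stays finite on the same batch where the CUDA path produces NaNs, that points at the compiled extension (e.g. a prebuilt wheel mismatched with the installed CUDA/PyTorch), which would be consistent with the source build fixing it.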