FMInference / DejaVu


PyTorch 1.12 and flash-attn==0.2.8 are not compatible. #27

Open heheda12345 opened 5 months ago

heheda12345 commented 5 months ago

Thanks for your great work! I am trying to reproduce the latency tests with the scripts in the Dejavu/benchmarks folder. I installed the recommended PyTorch 1.12 and flash-attn==0.2.8, but these two libraries are not compatible: I get the following error from this line in flash-attn, which calls `get_global_rank`, an API that is not available in PyTorch 1.12 and only exists in newer PyTorch releases. Which library versions should I use to reproduce the results?

```
    p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
AttributeError: module 'torch.distributed' has no attribute 'get_global_rank'
```
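For what it's worth, a shim along these lines might paper over the missing API, assuming the public `get_global_rank` added in newer PyTorch is equivalent to the private `torch.distributed.distributed_c10d._get_global_rank` helper that PyTorch 1.12 already ships (same `(group, group_rank)` signature). It would have to run before the failing flash-attn code is imported:

```python
import torch.distributed as dist

# Sketch of a compatibility shim for PyTorch 1.12. Assumption: the public
# get_global_rank in newer PyTorch wraps the older private helper below
# with the same (group, group_rank) signature. Run before importing flash_attn.
if not hasattr(dist, "get_global_rank"):
    from torch.distributed.distributed_c10d import _get_global_rank
    dist.get_global_rank = _get_global_rank
```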

Also, the scripts load a weight file called `full.pt`, which is not in OPT's Hugging Face repo. How should I obtain this file?