asahi417 / lmppl

Calculate perplexity on a text with pre-trained language models. Supports masked LMs (MLM, e.g. DeBERTa), causal (autoregressive) LMs (e.g. GPT-3), and encoder-decoder LMs (e.g. Flan-T5).
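For reference, perplexity is the exponential of the mean negative log-likelihood per token. The sketch below assumes you already have natural-log token probabilities (for a causal LM, log p(x_i | x_<i)). `perplexity` is a hypothetical helper, and the sketch does not claim to show how lmppl itself scores MLM or encoder-decoder models.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """Return exp(mean negative log-likelihood) over the scored tokens.

    `token_log_probs` holds natural-log probabilities, one per token
    (assumed input; how they are obtained depends on the model type).
    """
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# Every token predicted with probability 0.5 gives a perplexity of about 2.
print(perplexity([math.log(0.5)] * 4))
```

Lower is better: a model that assigns probability 1 to every token reaches the minimum perplexity of 1.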
MIT License

fixed tensor device mismatch when using flash_attn #12

Closed: kaiyamclarke closed this pull request 1 month ago.

kaiyamclarke commented 1 month ago

Fixes a device mismatch between input tensors in PyTorch: `model_inputs` is now placed on `cuda:0` instead of being left on the CPU.

Previous error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)
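The error says the `target` argument of `nll_loss` (the labels) was on the CPU while the other operand was on `cuda:0`. Below is a minimal sketch of the kind of fix described above, assuming the inputs are a dict (possibly nested with lists or tuples) of tensors. `move_to_device` is a hypothetical helper, not lmppl's actual code. It moves anything that exposes a `.to(device)` method, which covers `torch.Tensor` and Hugging Face `BatchEncoding`.

```python
from typing import Any


def move_to_device(obj: Any, device: Any) -> Any:
    """Recursively move tensors inside dicts, lists and tuples to `device`.

    Any object with a callable `.to` method (e.g. torch.Tensor or a
    transformers BatchEncoding) is moved; other values are returned as-is.
    """
    if hasattr(obj, "to") and callable(obj.to):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(value, device) for value in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(value, device) for value in obj)
    return obj


# Hypothetical usage inside a scoring loop (`model` and `model_inputs` assumed):
#   device = next(model.parameters()).device
#   model_inputs = move_to_device(model_inputs, device)
```

Reading the device from the model's parameters, rather than hard-coding `cuda:0`, keeps the inputs and labels on the same device as the logits even if the model is loaded onto a different GPU.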