Calculate perplexity on a text with pre-trained language models. Supports MLM (e.g. DeBERTa), recurrent LM (e.g. GPT-3), and encoder-decoder LM (e.g. Flan-T5).
fixed tensor device mismatch when using flash_attn #12
Fixed a device mismatch between input tensors in PyTorch by moving model_inputs to cuda:0 instead of leaving them on the CPU.
Previous error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)
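A minimal sketch of this kind of fix, assuming a Hugging Face `transformers` causal LM; the checkpoint name and the `model_inputs` / `perplexity` variable names are illustrative, not the repository's actual code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder checkpoint for illustration
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Tokenizer output defaults to CPU tensors.
model_inputs = tokenizer("Example sentence for perplexity.", return_tensors="pt")

# The fix: move every input tensor onto the model's device before the forward pass,
# so the NLL loss sees logits and targets on the same device.
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

with torch.no_grad():
    # Labels on a different device than the logits trigger the
    # "Expected all tensors to be on the same device" error shown above.
    output = model(**model_inputs, labels=model_inputs["input_ids"])
    perplexity = torch.exp(output.loss)
print(float(perplexity))
```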