sillsdev / silnlp

A set of pipelines for performing experiments on various NLP tasks with a focus on resource-poor/minority languages.
Other

Load HF model using torch.float16 #292

Closed: ddaspit closed this 10 months ago

ddaspit commented 10 months ago

We fine-tune HF models using mixed precision (fp16). Even so, I believe that models like NLLB are still loaded with torch.float32 weights. We should try forcing the model to load with torch.float16 to see whether it reduces memory usage and speeds up inference. The dtype can be specified when the model is loaded, via the torch_dtype argument to from_pretrained.
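As a rough, weights-only estimate of what this could save: the sketch below assumes an illustrative parameter count of 600M (roughly the size of the distilled NLLB-200 600M checkpoint) and ignores activations, the beam-search cache and any optimizer state, so real savings will differ. It uses only the standard library to get the element sizes.

```python
import struct

# Bytes per element, taken from the struct format codes:
# 'f' is IEEE 754 binary32 (float32), 'e' is binary16 (float16).
BYTES_PER_PARAM = {
    "float32": struct.calcsize("f"),
    "float16": struct.calcsize("e"),
}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Memory needed to hold the weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


# Illustrative parameter count (assumption), roughly an NLLB-200 600M checkpoint.
num_params = 600_000_000
for dtype in ("float32", "float16"):
    print(f"{dtype}: {weight_memory_gib(num_params, dtype):.2f} GiB")
```

With these assumptions it prints about 2.24 GiB for float32 and 1.12 GiB for float16: halving the bytes per parameter halves the weight footprint, whatever the exact parameter count.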

isaac091 commented 10 months ago

I'm getting an error when I try this: ValueError: Attempting to unscale FP16 gradients. From what I've read, you can't train a model loaded with torch_dtype=torch.float16, because the optimizer step still needs float32 weights, and mixed precision on its own is supposed to take care of the fp16 parts. So I'm not sure this is something we can do.
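A simplified illustration of why the optimizer step wants float32 weights: near 1.0, float16 values are spaced 2^-10 (about 0.001) apart, so a small update is rounded away and the weight never moves. This is a standard-library simulation (struct format 'e' is IEEE 754 binary16) with a made-up weight and update size; it is not how PyTorch performs the step. The ValueError itself is raised by PyTorch's GradScaler when it finds float16 gradients, which is what you get when the parameters themselves are float16.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


weight = 1.0
update = 1e-4  # hypothetical: learning rate * gradient for a single step

# Apply the same small update 100 times, storing the weight in each precision.
w16, w32 = weight, weight
for _ in range(100):
    w16 = to_fp16(w16 + update)
    w32 = to_fp32(w32 + update)

print(f"fp16 weight after 100 updates: {w16}")
print(f"fp32 weight after 100 updates: {w32:.6f}")
```

The float16 weight stays at exactly 1.0 after 100 updates, while the float32 weight reaches roughly 1.01. Mixed-precision training avoids this by keeping the weights in float32 and using float16 only for parts of the forward and backward computation.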

ddaspit commented 10 months ago

When are you loading the model with torch_dtype=torch.float16: before training or before inference? Try using it only for inference.
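A minimal sketch of why casting only for inference is safe, assuming trained weights that fall in a typical small range (the uniform values below are a hypothetical stand-in for a real checkpoint): a one-time cast to float16 moves each weight by at most half a float16 spacing, and because nothing is trained afterwards, those errors never accumulate. In the actual pipeline this corresponds to loading with torch_dtype=torch.float16 (or calling model.half()) only in the inference code path.

```python
import random
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Hypothetical stand-in for trained weights; real checkpoints are far larger.
random.seed(0)
trained = [random.uniform(-0.2, 0.2) for _ in range(10_000)]

# One-time cast for inference; the training copy is left untouched.
inference = [to_fp16(w) for w in trained]

max_abs_err = max(abs(a - b) for a, b in zip(trained, inference))
print(f"max absolute error after the float16 cast: {max_abs_err:.2e}")
```

For weights below 0.25 in magnitude the error is bounded by 2^-14 (about 6.1e-5); whether that is acceptable for translation quality is something the experiment would still need to confirm.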