Closed King-Rafat closed 5 months ago
Hi @King-Rafat!
Based on my (rather limited) experiments, training SONAR models at half precision (`float16`) can sometimes be unstable when computing the cross-entropy loss for the decoder. So I would recommend `float32` or some form of mixed precision for training.
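If you still want the memory and speed benefits during training, a common middle ground is PyTorch automatic mixed precision, which runs the matmuls in `float16` while keeping the loss and gradient scaling in `float32`. Below is a minimal sketch of that pattern; the model, data, and hyperparameters are placeholders rather than the actual SONAR training setup:
```python
# Minimal mixed-precision training sketch with PyTorch AMP.
# The model and data are placeholders, not the actual SONAR training code.
import torch
from torch import nn

device = torch.device("cuda")
model = nn.Linear(1024, 32_000).to(device)  # stand-in for a decoder output projection
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    # dummy batch: sentence embeddings and target token ids
    x = torch.randn(32, 1024, device=device)
    y = torch.randint(0, 32_000, (32,), device=device)

    optimizer.zero_grad()
    # matmuls run in float16, but the cross-entropy loss is computed in float32
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = model(x)
        loss = nn.functional.cross_entropy(logits.float(), y)

    # loss scaling prevents float16 gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
The `GradScaler` handles the loss scaling that usually avoids the `float16` gradient underflow behind this kind of instability.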
However, inference with SONAR text models in `float16` seems to be totally fine. The code snippet below illustrates how SONAR translation quality is essentially unaffected by half-precision inference.
```python
import datasets
import torch
from sacrebleu import BLEU
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline

# loading the models to GPU (by default, they are in float32 precision)
device = torch.device("cuda")
t2vec_model = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
)
vec2text_model = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
)

# setting the test dataset
src_lang, tgt_lang = "eng_Latn", "fra_Latn"
lang2flores = {
    lang: datasets.load_dataset("facebook/flores", lang, trust_remote_code=True)
    for lang in [src_lang, tgt_lang]
}
source = lang2flores[src_lang]["dev"]["sentence"]
target = lang2flores[tgt_lang]["dev"]["sentence"]

# computing the embeddings in two precisions
embs_32 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)
t2vec_model.half()
embs_16 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)

# translating each embeddings matrix into French in two precisions
pred_32x32 = vec2text_model.predict(embs_32, target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x32 = vec2text_model.predict(embs_16.to(torch.float32), target_lang=tgt_lang, batch_size=32, progress_bar=True)
vec2text_model.half()
pred_32x16 = vec2text_model.predict(embs_32.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x16 = vec2text_model.predict(embs_16.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)

# evaluating the quality (higher BLEU <=> better)
bleu_calc = BLEU()
print(bleu_calc.corpus_score(pred_32x32, [target]).score)  # 45.35502456250957
print(bleu_calc.corpus_score(pred_16x32, [target]).score)  # 45.41064939316419
print(bleu_calc.corpus_score(pred_32x16, [target]).score)  # 45.385803594567314
print(bleu_calc.corpus_score(pred_16x16, [target]).score)  # 45.42170584536023
```
Also, there is evidence that NLLB models can be quantized (e.g. with ctranslate2) even to `int8` representations without serious performance degradation. And SONAR text models are essentially a fine-tuned NLLB model (but with a fixed-size representation bottleneck), so I would expect them to be quantizable to `int8` as well.
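For reference, here is a sketch of the standard ctranslate2 `int8` workflow for an NLLB checkpoint (the model name and output directory are just examples, and this is not a SONAR-specific recipe; applying the same idea to SONAR would need its own conversion path):
```python
# Convert a Hugging Face NLLB checkpoint to a ctranslate2 model with int8 weights.
# One-time shell step (example model name and output path):
#   ct2-transformers-converter --model facebook/nllb-200-distilled-600M \
#       --output_dir nllb-200-600M-int8 --quantization int8

import ctranslate2
import transformers

src_lang, tgt_lang = "eng_Latn", "fra_Latn"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang=src_lang
)
translator = ctranslate2.Translator("nllb-200-600M-int8", device="cpu", compute_type="int8")

# ctranslate2 expects token strings, with the target language as a decoding prefix
source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, how are you?"))
results = translator.translate_batch([source_tokens], target_prefix=[[tgt_lang]])

# drop the target-language prefix token before decoding back to text
output_tokens = results[0].hypotheses[0][1:]
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(output_tokens)))
```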
Hi @avidale, thank you for your feedback!
Hi, great work done here! Have you tried training or running inference with the models at lower precision? What is the performance loss from that?