facebookresearch / SONAR

SONAR, a new multilingual and multimodal fixed-size sentence embedding space, with a full suite of speech and text encoders and decoders.

Training on lower precision #25

Closed King-Rafat closed 5 months ago

King-Rafat commented 5 months ago

Hi, great work here! Have you tried training the models or running inference at lower precision? How large is the performance loss in that case?

avidale commented 5 months ago

Hi @King-Rafat! Based on my (rather limited) experiments, training SONAR models at half precision (float16) can sometimes be unstable when computing the cross-entropy loss for the decoder. So I would recommend float32 or mixed precision for training.
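
For what it's worth, the usual way to keep most of the float16 speedup while avoiding that instability is automatic mixed precision. Below is a minimal, generic PyTorch sketch (placeholder model and data, not SONAR's actual training code): `torch.autocast` runs most ops in float16 but keeps loss-sensitive ones such as cross-entropy in float32, and `GradScaler` prevents gradient underflow.

```python
import torch

# Hypothetical sketch, not SONAR training code: a linear layer stands in for the decoder head.
vocab_size, hidden = 1000, 512
model = torch.nn.Linear(hidden, vocab_size).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()  # rescales float16 gradients to avoid underflow

# dummy batches: hidden states and target token ids
batches = [(torch.randn(8, hidden), torch.randint(0, vocab_size, (8,))) for _ in range(4)]

for feats, targets in batches:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(feats.cuda())
        # autocast promotes cross_entropy to float32, which is exactly the op
        # that tends to blow up in a pure float16 run
        loss = torch.nn.functional.cross_entropy(logits, targets.cuda())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```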

However, inference with the SONAR text models in float16 seems to be totally fine. The code snippet below illustrates that SONAR translation quality isn't affected by running the encoder and the decoder at half precision.

Example

```python
import datasets
import torch
from sacrebleu import BLEU

from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline

# loading the models to GPU (by default, they are in float32 precision)
device = torch.device("cuda")
t2vec_model = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder", device=device
)
vec2text_model = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder", tokenizer="text_sonar_basic_encoder", device=device
)

# setting the test dataset
src_lang, tgt_lang = "eng_Latn", "fra_Latn"
lang2flores = {
    lang: datasets.load_dataset("facebook/flores", lang, trust_remote_code=True)
    for lang in [src_lang, tgt_lang]
}
source = lang2flores[src_lang]['dev']['sentence']
target = lang2flores[tgt_lang]['dev']['sentence']

# computing the embeddings in two precisions
embs_32 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)
t2vec_model.half()
embs_16 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)

# translating each embeddings matrix into French in two precisions
pred_32x32 = vec2text_model.predict(embs_32, target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x32 = vec2text_model.predict(embs_16.to(torch.float32), target_lang=tgt_lang, batch_size=32, progress_bar=True)
vec2text_model.half()
pred_32x16 = vec2text_model.predict(embs_32.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x16 = vec2text_model.predict(embs_16.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)

# evaluating the quality (higher BLEU <=> better)
bleu_calc = BLEU()
print(bleu_calc.corpus_score(pred_32x32, [target]).score)  # 45.35502456250957
print(bleu_calc.corpus_score(pred_16x32, [target]).score)  # 45.41064939316419
print(bleu_calc.corpus_score(pred_32x16, [target]).score)  # 45.385803594567314
print(bleu_calc.corpus_score(pred_16x16, [target]).score)  # 45.42170584536023
```
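
All four BLEU scores fall within about 0.07 points of each other, so switching the encoder, the decoder, or both to float16 has no measurable effect on translation quality in this setup.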

Also, there is evidence that NLLB models can be quantized (e.g. with ctranslate2) even to int8 representations without serious performance degradation. Since SONAR text models are essentially a fine-tuned NLLB model (but with a fixed-size representation bottleneck), I would expect them to be quantizable to int8 as well.
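
For reference, this is roughly what the int8 path looks like for NLLB itself, following the ctranslate2 documentation (the converter does not cover SONAR checkpoints directly, so treat this purely as an illustration of int8 quantization for the underlying NLLB architecture):

```python
# Hypothetical illustration based on the ctranslate2 NLLB example, not a SONAR API.
# Conversion step (run once, on the command line):
#   ct2-transformers-converter --model facebook/nllb-200-distilled-600M \
#       --output_dir nllb-600M-int8 --quantization int8
import ctranslate2
import transformers

translator = ctranslate2.Translator("nllb-600M-int8", device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)

# tokenize the source sentence and translate with a target-language prefix
source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, world!"))
results = translator.translate_batch([source_tokens], target_prefix=[["fra_Latn"]])
target_tokens = results[0].hypotheses[0][1:]  # drop the language tag

print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens)))
```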

King-Rafat commented 5 months ago

Hi @avidale, thank you for your feedback!