fe1ixxu / ALMA

State-of-the-art LLM-based translation models.

Training metrics currently not logged? #37

Open SirRob1997 opened 3 months ago

SirRob1997 commented 3 months ago

I'm trying to replicate the training process. Is it expected that no metrics are logged during eval, such as automatic translation quality metrics (e.g. sacreBLEU)?

It seems like https://github.com/fe1ixxu/ALMA/blob/b92304c6548b7d0af0cdadca9d63c07c70d19cd5/run_llmmt.py#L152C5-L152C11 is never used. Shouldn't there be a compute_metrics function that is passed to the trainer?

fe1ixxu commented 3 months ago

Thanks for your interest!

Yes, metrics will not be logged during eval; only the eval loss will be shown if you enable --do_eval. metric = evaluate.load("sacrebleu") is a leftover line and will be deleted in the future.

SirRob1997 commented 3 months ago

@fe1ixxu given the diverging behaviour between chrF and COMET, it might be worth tracking both during validation. How would you integrate that into the LlmmtTrainer? There seems to be custom code for prediction_step if we disable --prediction_loss_only, but I've run into some issues with it (likely because it's never used) during prediction_step when calling self.model.generate(**inputs, **gen_kwargs).

*** RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
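
For what it's worth, that error usually surfaces when multinomial sampling runs over logits containing inf/NaN (e.g. under fp16), so forcing deterministic decoding in gen_kwargs is one way to check whether sampling is the trigger. A rough sketch only, not verified against this repo's actual gen_kwargs:

# Sketch only: generation kwargs that avoid multinomial sampling, to check
# whether the inf/nan probability error is triggered by do_sample=True.
# The exact values are placeholders, not the settings used in this repo.
gen_kwargs = {
    "max_new_tokens": 256,  # placeholder generation budget
    "num_beams": 5,         # beam search instead of sampling
    "do_sample": False,     # sampling is what raises the error above
}
generated_tokens = self.model.generate(**inputs, **gen_kwargs)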

Likely, something along these lines should be passed to the trainer to track chrF:

import evaluate
import numpy as np

# Assumes `tokenizer` and `logger` from the surrounding training script are in scope.
def compute_metrics_chrf_comet(eval_preds):
    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]
        return preds, labels

    chrf_metric = evaluate.load("chrf")
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Replace -100 (ignored positions) with the pad token id before decoding
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    logger.info(f"Sample decoded prediction {decoded_preds[0]}")
    logger.info(f"Sample decoded label {decoded_labels[0]}")

    result_chrf = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"chrF": result_chrf["score"]}
    return result
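
And then it would be hooked up roughly like this. This is only a sketch assuming LlmmtTrainer forwards the standard Hugging Face Trainer kwargs; the variable names are placeholders for whatever run_llmmt.py already constructs:

# Sketch only: wiring the callback into the trainer. Assumes LlmmtTrainer
# accepts the standard Trainer kwargs; training_args, train_dataset,
# eval_dataset and data_collator stand in for what run_llmmt.py builds.
trainer = LlmmtTrainer(
    model=model,
    args=training_args,              # with --do_eval enabled and --prediction_loss_only disabled
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_chrf_comet,  # called on (predictions, labels) after each eval loop
)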