Open SirRob1997 opened 3 months ago
Thanks for your interest!
Yes, metrics will not be logged during eval; only the eval loss will show if you enable `--do_eval`. `metric = evaluate.load("sacrebleu")` is a leftover line and will be deleted in the future.
@fe1ixxu given the diverging behaviour between chrF and COMET, it might be worth tracking both during validation. How would you integrate that into the `LlmmtTrainer`? There seems to be custom code for `prediction_step` if we disable `--prediction_loss_only`, but I've run into some issues with it (likely because it's never used) during the `prediction_step` when calling `self.model.generate(**inputs, **gen_kwargs)`:
*** RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
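For context, as far as I can tell this message comes from PyTorch's sampling step (`torch.multinomial`), which rejects a probability tensor that is not a valid distribution. That usually means the logits already contained `inf` or `nan` before the softmax. Below is a minimal pure-Python sketch of that failure mode, not the actual ALMA or PyTorch code; `softmax` and `is_valid_distribution` are hypothetical helper names used only for illustration.

```python
import math


def softmax(logits):
    """Numerically stabilised softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def is_valid_distribution(probs):
    """True only if every entry is finite and non-negative."""
    return all(math.isfinite(p) and p >= 0 for p in probs)


healthy = softmax([1.0, 2.0, 3.0])
# A single inf logit produces inf - inf = nan inside the softmax,
# and the nan then spreads to every probability via the shared sum.
broken = softmax([1.0, float("inf"), 3.0])

print(is_valid_distribution(healthy))
print(is_valid_distribution(broken))
```

If an equivalent check fails on the real logits, half-precision overflow would be one suspect worth ruling out. Evaluating with greedy decoding (`do_sample=False`) may also sidestep the sampling call, although I have not confirmed that this resolves the issue here.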
Likely, something along these lines should be passed to the trainer to track chrF:
import evaluate
import numpy as np

# `tokenizer` and `logger` are assumed to come from the enclosing training script.
def compute_metrics_chrf_comet(eval_preds):
    # Only chrF is computed so far; COMET would still need to be added.
    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        # chrF expects a list of references per prediction
        labels = [[label.strip()] for label in labels]
        return preds, labels

    chrf_metric = evaluate.load("chrf")
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Replace -100 (ignored positions) with the pad token id, since -100 cannot be decoded
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    logger.info(f"Sample decoded prediction {decoded_preds[0]}")
    logger.info(f"Sample decoded label {decoded_labels[0]}")
    result_chrf = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"chrF": result_chrf["score"]}
    return result
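To track chrF and COMET together, one option would be to merge several metric functions into a single `compute_metrics`-style callable. This is only a sketch under that assumption: `combine_metrics` is a hypothetical helper, and the two toy metrics below are stand-ins that operate on already-decoded strings, not real chrF or COMET computations.

```python
def combine_metrics(*metric_fns):
    """Return one callable that merges the dicts returned by each
    metric function (hypothetical helper, not part of ALMA)."""

    def compute_metrics(eval_preds):
        result = {}
        for fn in metric_fns:
            scores = fn(eval_preds)
            # Refuse to silently overwrite a metric with the same name.
            overlap = result.keys() & scores.keys()
            if overlap:
                raise ValueError(f"duplicate metric names: {sorted(overlap)}")
            result.update(scores)
        return result

    return compute_metrics


# Toy stand-ins for real metric functions such as a chrF or COMET wrapper.
def exact_match(eval_preds):
    preds, labels = eval_preds
    hits = sum(p == l for p, l in zip(preds, labels))
    return {"exact_match": hits / len(preds)}


def mean_pred_length(eval_preds):
    preds, _ = eval_preds
    return {"mean_pred_length": sum(len(p) for p in preds) / len(preds)}


compute_metrics = combine_metrics(exact_match, mean_pred_length)
scores = compute_metrics((["a b", "c"], ["a b", "d"]))
print(scores)
```

The chrF function above could take the place of one of the toy metrics. The stock Hugging Face `Trainer` accepts such a callable through its `compute_metrics` argument; whether `LlmmtTrainer` forwards it unchanged is something I have not verified.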
I'm trying to replicate the training process. Is it expected that no metrics are logged during eval, such as automatic translation quality metrics (e.g. using `sacrebleu`)? It seems like https://github.com/fe1ixxu/ALMA/blob/b92304c6548b7d0af0cdadca9d63c07c70d19cd5/run_llmmt.py#L152C5-L152C11 is never used. There should probably be a `compute_metrics` function that is passed to the trainer?