fe1ixxu / ALMA

State-of-the-art LLM-based translation models.
MIT License
440 stars, 35 forks

Issues with Translation Quality Using ALMA/ALMA-R Models on Multi-Domain Dataset #56

Closed: cocaer closed this issue 4 months ago

cocaer commented 4 months ago

Hi Haoran,

I've been experimenting with the ALMA and ALMA-R models to translate the multi-domain dataset available here. Unfortunately, I'm observing subpar results, particularly with the ALMA-R model, in terms of both BLEU and COMET scores.

Below is the code snippet I've been using:

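import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model():
    global model, tokenizer
    model_path = "/data4/jibj/model/ALMA-13B-R"  # local copy of the ALMA-13B-R weights
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
    # Decoder-only model: pad on the left so generation starts right after the prompt
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
    print(model)

def translate(text, srclang, tgtlang):
    # ALMA prompt, e.g. "Translate this from German to English:\nGerman: ...\nEnglish:"
    prompt = f"Translate this from {srclang} to {tgtlang}:\n{srclang}: {text}\n{tgtlang}:"
    print(prompt)
    # Caution: truncation cuts from the right by default, so a source longer than
    # ~200 tokens would lose the trailing "{tgtlang}:" cue
    input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=200, truncation=True).input_ids.to(model.device)
    with torch.no_grad():
        # Beam search (5 beams) combined with sampling at temperature 0.6, top_p 0.9
        tmp = model.generate(input_ids=input_ids, num_beams=5, max_new_tokens=200, do_sample=True, temperature=0.6, top_p=0.9)
        # Drop the echoed prompt tokens and keep only the generated translation
        generated_ids = tmp[:, input_ids.size(1):]
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return outputs[0]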
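In the snippet above, tokenizer(..., max_length=200, truncation=True) truncates the whole prompt, and Hugging Face tokenizers truncate from the right by default. For a long source sentence, that can remove the trailing "{tgtlang}:" cue the model needs. The sketch below is a hypothetical illustration only: it uses whitespace-separated words as a stand-in for subword tokens, and the helper names (build_prompt, right_truncate, truncate_source_then_prompt) are not part of ALMA or transformers. It simply contrasts truncating the whole prompt with truncating only the source text.

```python
def build_prompt(text: str, srclang: str, tgtlang: str) -> str:
    # Same template as translate() in the snippet above
    return f"Translate this from {srclang} to {tgtlang}:\n{srclang}: {text}\n{tgtlang}:"


def right_truncate(prompt: str, max_words: int) -> str:
    # Rough stand-in for right-side truncation of the whole prompt
    # (words instead of real subword tokens; hypothetical helper)
    return " ".join(prompt.split(" ")[:max_words])


def truncate_source_then_prompt(text: str, srclang: str, tgtlang: str, max_source_words: int) -> str:
    # Hypothetical alternative: shorten only the source sentence, then build
    # the prompt, so the "\n{tgtlang}:" cue is always kept at the end
    words = text.split()
    return build_prompt(" ".join(words[:max_source_words]), srclang, tgtlang)


if __name__ == "__main__":
    long_text = " ".join(["word"] * 50)
    full = build_prompt(long_text, "German", "English")
    print(repr(right_truncate(full, 20)[-30:]))  # the "English:" cue is gone
    print(repr(truncate_source_then_prompt(long_text, "German", "English", 10)[-30:]))
```

This only matters for inputs longer than the tokenizer's max_length; in this thread, the actual cause turned out to be a broken model file.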

Is there something wrong with my approach, or is this typical performance given that the ALMA models were trained primarily on WMT and Flores data? Any insights or recommendations would be greatly appreciated.

Baijun

fe1ixxu commented 4 months ago

Hi,

Thank you for your interest! ALMA should generalize across domains. If you are unsure whether you are using ALMA correctly, we recommend running the following command:

accelerate launch --config_file configs/deepspeed_eval_config_bf16.yaml \
    run_llmmt.py \
    --model_name_or_path haoranxu/ALMA-13B-R \
    --text_test_file $YOUR_RAW_TEXT \
    --do_predict \
    --low_cpu_mem_usage \
    --language_pairs en-cs \
    --mmt_data_path ./human_written_data/ \
    --per_device_eval_batch_size 4 \
    --output_dir ./your_output_dir/ \
    --predict_with_generate \
    --max_new_tokens 512 \
    --max_source_length 512 \
    --bf16 \
    --seed 42 \
    --num_beams 5 \
    --overwrite_cache \
    --overwrite_output_dir

Here, text_test_file is a single raw text file containing one sentence to translate per line; when it is given, it overrides mmt_data_path. Replace en-cs in --language_pairs with your own single translation direction.
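The text_test_file format is one sentence per line. A minimal sketch of writing such a file from source sentences you already have in Python; write_text_test_file is a hypothetical helper, not part of run_llmmt.py. Collapsing embedded line breaks keeps each sentence on exactly one line, so the outputs stay aligned line by line with your references when you compute BLEU or COMET.

```python
from pathlib import Path
from typing import Iterable


def write_text_test_file(sentences: Iterable[str], path: str) -> int:
    """Write one source sentence per line (hypothetical helper).

    Internal newlines, tabs, and repeated spaces are collapsed to a single
    space so no sentence spills onto a second line. Returns the line count.
    """
    lines = [" ".join(s.split()) for s in sentences]
    Path(path).write_text("".join(line + "\n" for line in lines), encoding="utf-8")
    return len(lines)


if __name__ == "__main__":
    n = write_text_test_file(["Hello world.", "Line one\nline two."], "test.en.txt")
    print(n, "lines written")
```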

It would also be helpful if you could share your detailed results to assist with debugging.

cocaer commented 4 months ago

I got good results after following your instructions. It seems I was using a broken model file. Thanks a lot!
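Because the problem turned out to be a broken model file, one way to rule this out before running an evaluation is to compare checksums of the local weight files against published ones. A sketch under stated assumptions: you collect the expected SHA-256 values yourself (on the Hugging Face Hub, LFS-tracked weight files show a SHA256 on their file page), and sha256_of_file and find_mismatched_files are hypothetical helper names.

```python
import hashlib
from pathlib import Path
from typing import Dict, List


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file in 1 MiB chunks so multi-GB shards need not fit in memory
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def find_mismatched_files(expected: Dict[str, str], model_dir: str) -> List[str]:
    # expected maps file name -> SHA-256 hex digest (assumption: copied by hand
    # from the model's published file listing). Returns missing or corrupted files.
    bad = []
    for name, want in expected.items():
        path = Path(model_dir) / name
        if not path.is_file() or sha256_of_file(str(path)) != want.lower():
            bad.append(name)
    return bad
```

An empty result means every listed file is present and matches; any name returned should be re-downloaded.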