huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0
9.69k stars 1.22k forks source link

[GKD] mismatch in tensors when stacking log probs #2215

Open nivibilla opened 9 hours ago

nivibilla commented 9 hours ago

System Info

Latest TRL installed from source. I can't run `trl env` right now because the cluster is shut down, but everything is installed from source.

If required, I will restart the cluster and run it.

Information

Tasks

Reproduction


def qlora_gkd_train():
    """Distil a 4-bit quantised teacher into a LoRA-adapted student with GKDTrainer.

    Reads hyper-parameters from ``/local_disk0/training_config.json`` and the
    training set from ``/local_disk0/train``, builds the tokenizer, the
    quantised teacher model, the student ``from_pretrained`` kwargs, the LoRA
    config and the ``GKDConfig``, then calls ``GKDTrainer.train()``. Training
    resumes from a checkpoint when ``training_config['resume']`` is true.

    NOTE: the GKD loss stacks student and teacher log-probs along the
    vocabulary axis, so the two checkpoints must share the same ``vocab_size``
    in their model configs; otherwise ``torch.stack`` fails with a size
    mismatch.
    """
    import datasets
    import torch
    import transformers

    from trl import (
        GKDConfig,
        GKDTrainer,
        LogCompletionsCallback,
    )
    from peft import LoraConfig, TaskType, prepare_model_for_kbit_training

    import json

    with open('/local_disk0/training_config.json', encoding='utf-8') as f:
        training_config = json.load(f)

    # Cap the run at a few steps while testing memory usage for the batch size.
    training_config['max_steps'] = 10
    # training_config['per_device_train_batch_size'] = 32
    print(json.dumps(training_config, indent=4))

    print("loading tokenizer")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        training_config['teacher_model_name_or_path'],
        padding_side="left",
        truncation_side="left",
    )
    tokenizer.pad_token = tokenizer.eos_token

    print("loading dataset")
    train_dataset = datasets.load_from_disk('/local_disk0/train')

    # Model
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    quantization_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )

    print("loading teacher model")
    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
        # The teacher must come from the teacher checkpoint, not the student one.
        training_config['teacher_model_name_or_path'],
        quantization_config=quantization_config,
        attn_implementation="flash_attention_2",  # "sdpa" is an alternative
        torch_dtype=quant_storage_dtype,
        device_map="auto",
    )

    teacher_model = prepare_model_for_kbit_training(teacher_model)

    print("create student config")
    student_model_kwargs = dict(
        trust_remote_code=True,
        attn_implementation="flash_attention_2",  # "sdpa" is an alternative
        torch_dtype=quant_storage_dtype,
        # The KV cache is incompatible with gradient checkpointing, so enable
        # it only when checkpointing is off.
        use_cache=not training_config['gradient_checkpointing'],
        device_map="auto",
        # quantization_config=quantization_config,
    )

    lora_config = LoraConfig(
        r=training_config['lora_r'],
        # target_modules="all-linear",
        target_modules=["q_proj", "k_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=training_config['lora_alpha'],
        lora_dropout=0.05,
    )

    training_arguments = GKDConfig(
        model_init_kwargs=student_model_kwargs,
        save_strategy='epoch',
        report_to='mlflow',
        # save_steps=training_config['save_steps'],
        ddp_find_unused_parameters=False,
        gradient_checkpointing=training_config['gradient_checkpointing'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
        num_train_epochs=training_config['num_train_epochs'],
        learning_rate=training_config['learning_rate'],
        warmup_ratio=training_config['warmup_ratio'],
        lr_scheduler_type="cosine",
        bf16=True,
        max_steps=training_config['max_steps'],
        logging_steps=training_config['logging_steps'],
        output_dir=training_config['output_dir'],
        gradient_checkpointing_kwargs={'use_reentrant': False},
        max_seq_length=training_config['max_seq_len'],
        use_liger=training_config['use_liger'],
        # optim="paged_adamw_8bit",
        dataset_text_field='prompt',
        packing=False,
        # gkd params
        temperature=0.9,
        max_new_tokens=1024,
    )

    print("start training")
    # The student is passed as a path so the trainer instantiates it itself
    # using ``model_init_kwargs`` from the config above.
    trainer = GKDTrainer(
        model=training_config['student_model_name_or_path'],
        teacher_model=teacher_model,
        args=training_arguments,
        train_dataset=train_dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    if training_config['resume']:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

# `os` is not imported anywhere in the visible script; import it here so the
# environment variable can be set before training starts.
import os

# Set before the trainer is built so the flag is visible when models are prepared.
os.environ['ACCELERATE_BYPASS_DEVICE_MAP'] = "true"

qlora_gkd_train()

For further details: the teacher is Qwen 2.5 72B Instruct and the student is Qwen 2.5 3B Instruct.

Training Config :

{
    "teacher_model_name_or_path": "/local_disk0/Qwen/Qwen2.5-72B-Instruct",
    "student_model_name_or_path": "/local_disk0/Qwen/Qwen2.5-3B-Instruct",
    "learning_rate": 1e-05,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "logging_steps": 1,
    "num_train_epochs": 15,
    "gradient_checkpointing": true,
    "use_peft": true,
    "lora_r": 64,
    "lora_alpha": 16,
    "max_seq_len": 1382,
    "use_liger": false,
    "warmup_ratio": 0.1,
    "resume": false,
    "max_steps": 10
}

Error Trace:

Traceback (most recent call last):
  File "/local_disk0/train.py", line 162, in <module>
    qlora_gkd_train()
  File "/local_disk0/train.py", line 158, in qlora_gkd_train
    trainer.train()
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-7579e674-61e9-403f-8884-27b859a93e4a/lib/python3.11/site-packages/transformers/trainer.py", line 2085, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-7579e674-61e9-403f-8884-27b859a93e4a/lib/python3.11/site-packages/transformers/trainer.py", line 2421, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-7579e674-61e9-403f-8884-27b859a93e4a/lib/python3.11/site-packages/trl/trainer/gkd_trainer.py", line 292, in training_step
    loss = super().training_step(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-7579e674-61e9-403f-8884-27b859a93e4a/lib/python3.11/site-packages/transformers/trainer.py", line 3524, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-7579e674-61e9-403f-8884-27b859a93e4a/lib/python3.11/site-packages/trl/trainer/gkd_trainer.py", line 239, in compute_loss
    loss = self.generalized_jsd_loss(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-7579e674-61e9-403f-8884-27b859a93e4a/lib/python3.11/site-packages/trl/trainer/gkd_trainer.py", line 190, in generalized_jsd_loss
    torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [4, 0, 151936] at entry 0 and [4, 0, 152064] at entry 1

I originally thought the difference might be due to the max sequence length, which is 1382 based on my dataset's maximum. But the reported dimensions differ by only 128 (152064 vs. 151936).

Expected behavior

The tokenizers for both Qwen 72B and 3B have a max length of 131072, so I'm not sure where the ~151k numbers are coming from.

Since it's the same tokenizer, I assume it should be possible to distill one into the other, right?

nivibilla commented 9 hours ago

Actually, I'm stupid. I figured it out while I was typing the issue. I should be looking at the vocab size (`vocab_size` in the model config), not the tokenizer's max length.

https://huggingface.co/Qwen/Qwen2.5-3B-Instruct/blob/aa8e72537993ba99e69dfaafa59ed015b17504d1/config.json#L26

nivibilla commented 9 hours ago

Is it worth adding a check in the GKD trainer for this param so this error is more readable for others?

nivibilla commented 9 hours ago

Llama 3.1 70B and Llama 3.2 1B seem to have the same vocab size, so I will test with those. It will probably work.