huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Last layer of the Llava-1.5 visual tower is not training #1936

Open qgallouedec opened 1 month ago

qgallouedec commented 1 month ago
import numpy as np
import torch
from datasets import Dataset, features
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

from trl import DPOConfig, DPOTrainer

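# Load the policy model, the DPO reference model, and the processor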
model_id = "trl-internal-testing/tiny-random-llava-1.5"
model = AutoModelForVision2Seq.from_pretrained(model_id)
ref_model = AutoModelForVision2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

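# Tiny preference dataset: each prompt pairs a chosen and a rejected completion with one random image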
dataset = Dataset.from_dict(
    {
        "prompt": [
            "USER: <image>\nDescribe the image in great detail. ",
            "USER: <image>\nIs this bus in the USA? ",
            "USER: <image>\nGive a thorough description of the image. ",
            "USER: <image>\nWho are the people in the image? ",
            "USER: <image>\nWhat is written? ",
        ],
        "chosen": [
            "ASSISTANT: The image features a modern, multi-colored train. ",
            "ASSISTANT: Yes, it can be assumed that this bus is in the USA. ",
            "ASSISTANT: The image features a forest path. ",
            "ASSISTANT: There are two individuals, possibly girls or women. ",
            'ASSISTANT: "ccpb". ',
        ],
        "rejected": [
            "ASSISTANT: The image features a modern, colorful train. ",
            "ASSISTANT: No, it's not in the USA. ",
            "ASSISTANT: The image features a forest path surrounded by trees. ",
            "ASSISTANT: In the image, there are two individuals. ",
            'ASSISTANT: "ccpb". ',
        ],
        "images": [
            [Image.fromarray(np.random.randint(0, 255, (92, 33, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (80, 152, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (57, 24, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (102, 48, 3), dtype=np.uint8))],
        ],
    }
)
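# Cast the images column so each entry is decoded as a PIL image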
dataset = dataset.cast_column("images", features.Sequence(features.Image()))

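# Minimal DPO configuration for a short smoke-test run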
training_args = DPOConfig(
    output_dir="tmp_dir",
    per_device_train_batch_size=2,
    max_length=512,
    max_prompt_length=128,
    remove_unused_columns=False,
    report_to="none",
)
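# For VLMs the processor is passed through the tokenizer argument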
trainer = DPOTrainer(model=model, ref_model=ref_model, args=training_args, tokenizer=processor, train_dataset=dataset)

# Save the initial weights so we can check whether they change during training
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

# Check that the trainable params have changed
for n, param in previous_trainable_params.items():
    new_param = trainer.model.get_parameter(n)
    if param.sum() != 0:  # skip zero-initialized params (e.g. biases)
        if torch.allclose(param, new_param, rtol=1e-12, atol=1e-12):
            print(f"Parameter {n} has not changed after training")
[2024-08-16 14:26:16,425] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py:47: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  @autocast_custom_fwd
/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py:66: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  @autocast_custom_bwd
Casting the dataset: 100%|█████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 1830.93 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 28.39 examples/s]
  0%|                                                                                                       | 0/9 [00:00<?, ?it/s]
Could not estimate the number of tokens of the input, floating-point operations will not be computed
{'train_runtime': 2.7064, 'train_samples_per_second': 5.542, 'train_steps_per_second': 3.325, 'train_loss': 0.6846402486165365, 'epoch': 3.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.32it/s]
Parameter vision_tower.vision_model.encoder.layers.3.self_attn.k_proj.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.self_attn.v_proj.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.self_attn.q_proj.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.self_attn.out_proj.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.layer_norm1.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.mlp.fc1.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.mlp.fc2.weight has not changed after training
Parameter vision_tower.vision_model.encoder.layers.3.layer_norm2.weight has not changed after training
Parameter vision_tower.vision_model.post_layernorm.weight has not changed after training

https://github.com/huggingface/trl/blob/0956dc17cccd1d6301f059ded301dbcbfaf99970/tests/test_dpo_trainer.py#L989-L995
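
One possibly relevant detail: LLaVA reads its image features from config.vision_feature_layer, which defaults to -2, so the last vision encoder layer and post_layernorm sit outside the forward path and would receive no gradient by construction. A minimal diagnostic sketch along those lines (same tiny checkpoint as above; the prompt text is just an illustration):

import numpy as np
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "trl-internal-testing/tiny-random-llava-1.5"
model = AutoModelForVision2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Features are taken from the second-to-last vision layer by default
print(model.config.vision_feature_layer)  # -2

# One forward/backward pass, then list vision-tower params that received no gradient
image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
inputs = processor(text="USER: <image>\nHi ASSISTANT: Hello", images=image, return_tensors="pt")
model(**inputs, labels=inputs["input_ids"]).loss.backward()

for n, p in model.named_parameters():
    if n.startswith("vision_tower") and p.grad is None:
        print(f"No gradient for {n}")

If that prints exactly the layer-3 and post_layernorm parameters from the run above, the behavior would be a consequence of the feature-selection default rather than of DPOTrainer itself.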

RylanSchaeffer commented 2 weeks ago

I thought Llava blocks gradients for the vision tower? My recollection is that freezing the vision tower is standard practice with VLMs today.
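
As far as I know, from_pretrained does not freeze anything by itself, so if deliberate freezing were the cause, the tower's parameters would show requires_grad=False. A quick check (same tiny checkpoint as above; just a sketch):

from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained("trl-internal-testing/tiny-random-llava-1.5")

# Deliberately frozen parameters would appear here with requires_grad=False
frozen = [n for n, p in model.named_parameters()
          if n.startswith("vision_tower") and not p.requires_grad]
print(frozen if frozen else "no frozen vision_tower parameters")

If the list comes back empty, the unchanged weights are more likely explained by the vision_feature_layer default discussed above than by a gradient block in the model.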