CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

ILQL training batch2 tensor dimensions error #540

Open GenVr opened 1 year ago

GenVr commented 1 year ago

Hi, I'm running ILQL training with a GPT-J model fine-tuned with this code. I don't hit this problem with the original pre-trained model, nor with a flan-xl.

Traceback (most recent call last):
  File "/home/jupyter/trlx/examples/summarize_rlhf/ilql_gptj.py", line 118, in <module>
    main()
  File "/home/jupyter/trlx/examples/summarize_rlhf/ilql_gptj.py", line 109, in main
    trlx.train(
  File "/home/jupyter/trlx/trlx/trlx.py", line 126, in train
    trainer.learn()
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 539, in learn
    results = self.evaluate()
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 384, in evaluate
    samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"])
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 276, in generate_eval
    return self.accelerator.unwrap_model(self.model).generate(
  File "/home/jupyter/trlx/trlx/models/modeling_ilql.py", line 307, in generate
    out = self.forward(
  File "/home/jupyter/trlx/trlx/models/modeling_ilql.py", line 263, in forward
    outputs = self.base_model(**forward_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 854, in forward
    transformer_outputs = self.transformer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 689, in forward
    outputs = block(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 309, in forward
    attn_outputs = self.attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 257, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 183, in _attn
    attn_output = torch.matmul(attn_weights, value)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [256, 101] but got: [256, 1].
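
For reference, here is a minimal sketch of the shape constraint that fails here (the shapes are assumptions read off the error message, not taken from the repo): torch.matmul on batched attention tensors dispatches to bmm, which needs the key/value length of value to match the last dimension of attn_weights. The [256, 1] suggests value carries only the newly generated token's states while the attention weights span all 101 cached positions, which usually points at past_key_values not lining up with the attention_mask during generation.

import torch

# Assumed shapes for illustration: batch 16 * 16 heads = 256, incremental
# decoding (1 query token) against a 101-position key/value cache.
attn_weights = torch.randn(256, 1, 101)  # [batch*heads, q_len, k_len]
value_ok = torch.randn(256, 101, 64)     # [batch*heads, k_len, head_dim]
value_bad = torch.randn(256, 1, 64)      # only the new token's value, no cache

torch.matmul(attn_weights, value_ok)     # works
torch.matmul(attn_weights, value_bad)    # RuntimeError: Expected size for first two
                                         # dimensions of batch2 tensor to be: [256, 101]
                                         # but got: [256, 1].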

This is my config:

config = TRLConfig(
    train=TrainConfig(
        seq_length=768,
        epochs=epochs,
        total_steps=total_steps,
        batch_size=batch_size,
        checkpoint_interval=eval_and_checkpoint,
        eval_interval=eval_and_checkpoint,
        pipeline="PromptPipeline",
        trainer="AccelerateILQLTrainer",
        save_best=True,
        checkpoint_dir="ckpts_ilql"
    ),
    model=ModelConfig(
        model_path=pretrained_model_path,
        num_layers_unfrozen=-1,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_path="gpt2",
        truncation_side="right",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        kwargs={
            "lr": 5.0e-5,
            "betas": [0.9, 0.999],
            "eps": 1.0e-8,
            "weight_decay": 1.0e-6,
        },
    ),
    scheduler=SchedulerConfig(
        name="cosine_annealing",
        kwargs=dict(T_max=1e12, eta_min=5.0e-5)
    ),
    method=ILQLConfig(
        name="ILQLConfig",
        tau=0.7,
        gamma=0.99,
        cql_scale=0.1,
        awac_scale=1,
        alpha=0.001,
        beta=0,
        steps_for_target_q_sync=5,
        two_qs=True,
        gen_kwargs=dict(max_new_tokens=256, top_k=20, beta=4, temperature=1.0)
    ),
)
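
One interaction worth flagging in this config: if seq_length is meant to cap prompt plus generated tokens, then with gen_kwargs max_new_tokens=256 an eval prompt has at most 768 - 256 = 512 tokens to work with, and truncation_side="right" drops the end of anything longer. A quick way to check the eval prompts against that budget (a sketch; check_prompt_budget is a hypothetical helper, with "gpt2" taken from tokenizer_path above):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
max_prompt_len = 768 - 256  # seq_length minus max_new_tokens

def check_prompt_budget(prompts):
    # Count tokens per prompt and flag any that would be truncated.
    lengths = [len(tok(p)["input_ids"]) for p in prompts]
    too_long = [i for i, n in enumerate(lengths) if n > max_prompt_len]
    print(f"{len(too_long)} of {len(lengths)} prompts exceed {max_prompt_len} tokens")
    return too_long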

Thanks.

maxreciprocate commented 1 year ago

Hi @GenVr! Can you share your training code as well, alongside your config? There might be an error in how the training data was passed in. Thanks!

GenVr commented 1 year ago

@maxreciprocate Regarding the dataset and training, this is the trlx.train() call I use:

trlx.train(
    samples=[(text, output) for text, output in zip(ttv_ds['train']['text'], ttv_ds['train']['output'])],
    rewards=labels,
    eval_prompts=ttv_ds['validation']['text'][:16],
    config=config,
)

Where:

samples = [(string, string), (string, string), ...]  # list of (prompt, output) tuples
labels = [0, 1, 0, 1, ...]                           # list of 0/1 rewards, one per sample
eval_prompts = [string, string, ...]                 # list of strings
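
Before calling trlx.train, a quick consistency check along these lines can rule out a data-shape mismatch (a sketch, assuming ttv_ds and labels as defined above):

samples = [(text, output)
           for text, output in zip(ttv_ds['train']['text'], ttv_ds['train']['output'])]
eval_prompts = ttv_ds['validation']['text'][:16]

# One reward per (prompt, output) pair.
assert len(samples) == len(labels), "samples and rewards must align one-to-one"
# Eval prompts should be plain strings, not (prompt, output) tuples.
assert all(isinstance(p, str) for p in eval_prompts)
# Floats are the safe choice for rewards, even when the labels are just 0/1.
rewards = [float(l) for l in labels]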

Thanks for your answer!