allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/
Apache License 2.0

Value is not broadcastable with batch_shape+event_shape #38

Open: vcvcvnvcvcvn opened this issue 1 year ago


My YAML config:

tokenizer:
  model_name: facebook/bart-large-cnn
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500

env:
  n_envs: 1
  args:
    max_prompt_length: 64
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args:
    n_steps: 512
    batch_size: 4
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: facebook/bart-large-cnn
      apply_model_parallel: False
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100

train_evaluation:
  eval_batch_size: 16
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    # - id: bleurt
    #   args:
    #     config_name: bleurt-large-512
    - id: diversity
      args: {}
    # - id: summaCZS
    #   args:
    #     granularity: sentence
    #     use_ent: True
    #     use_con: False
    # - id: summaCConv
    #   args:
    #     granularity: sentence
  generation_kwargs:
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100

The error I get:

/content/RL4LMs/scripts/training/train_text_generation.py in main(config_path, project_name, experiment_name, base_path_to_store_results, entity_name, log_to_wandb)
     53             tracker=tracker,
     54         )
---> 55     trainer.train_and_eval()
     56 
     57 

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in train_and_eval(self)
    195         # evaluate on val and test set before fine-tuning once
    196         iter_start = self._trainer_state["current_iter"]
--> 197         self._evaluate_on_datapools(epoch=iter_start)
    198 
    199         # train for given number of iters

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in _evaluate_on_datapools(self, epoch, splits)
    181                                splits: List[str] = ["val", "test"]):
    182         for split in splits:
--> 183             evaluate_on_samples(policy=self._alg.policy,
    184                                 tokenizer=self._tokenizer,
    185                                 samples=self._samples_by_split[split],

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in evaluate_on_samples(policy, tokenizer, samples, batch_size, max_prompt_length, metrics, epoch, split_name, tracker, dt_control_token, gen_kwargs)
     39     n_samples = len(samples)
     40     for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"):
---> 41         batch_generated_texts = generate_text(
     42             policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs
     43         )

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in generate_text(policy, tokenizer, samples, max_prompt_length, dt_control_token, gen_kwargs)
    109         dt_control_token + sample.prompt_or_input_text for sample in samples
    110     ]
--> 111     generated_texts = policy.generate(
    112         tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs
    113     ).gen_texts

/content/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py in generate(self, tokenizer, texts, max_prompt_length, input_ids, attention_mask, gen_kwargs)
    254             actions_at_step = gen_tokens[:, step]
    255             distribution = Categorical(logits=raw_logits)
--> 256             log_probs = distribution.log_prob(actions_at_step)
    257             step_wise_logprobs.append(log_probs)
    258             step_wise_actions.append(actions_at_step)

/usr/local/lib/python3.8/dist-packages/torch/distributions/categorical.py in log_prob(self, value)
    121     def log_prob(self, value):
    122         if self._validate_args:
--> 123             self._validate_sample(value)
    124         value = value.long().unsqueeze(-1)
    125         value, log_pmf = torch.broadcast_tensors(value, self.logits)

/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    280         for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
    281             if i != 1 and j != 1 and i != j:
--> 282                 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
    283                                  format(actual_shape, expected_shape))
    284         try:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([400]).
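The check that fails lives in PyTorch, not in RL4LMs itself: Categorical.log_prob requires the shape of the value to broadcast against batch_shape, which for Categorical(logits=...) is logits.shape[:-1]. The standalone sketch below uses synthetic random tensors (the sizes 100 and 400 are chosen only to mirror the message above, and the vocabulary size of 8 is arbitrary) to reproduce the same ValueError, then shows that the call succeeds once the logits have exactly one row per action. Read this way, the error means raw_logits had four times as many rows as actions_at_step at that step. One possible cause, which the traceback alone does not confirm, is that generation ran with several beams, since Hugging Face generate reports scores per beam.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
vocab_size = 8  # arbitrary, for illustration only

# 400 rows of logits -> batch_shape == torch.Size([400])
logits = torch.randn(400, vocab_size)
# Only 100 token ids for this step -> value shape torch.Size([100])
actions = torch.randint(0, vocab_size, (100,))

dist = Categorical(logits=logits, validate_args=True)
try:
    dist.log_prob(actions)
    message = None
except ValueError as err:
    # Same check as Distribution._validate_sample in the traceback
    message = str(err)
print(message)

# With one row of logits per action, the shapes line up and log_prob works.
aligned = Categorical(logits=logits[:100], validate_args=True)
log_probs = aligned.log_prob(actions)
print(log_probs.shape)
```

The first print shows the broadcast error; the second shows a log-probability per action once the row counts match.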

By the way, can I use the max_size argument of the datapool, as in

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500

to control how many samples are used?
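For illustration only, a size cap like this is commonly implemented as a simple truncation of each split's sample list. The sketch below assumes (not confirmed for RL4LMs) that max_size keeps the first max_size samples and that None means "no cap"; cap_samples is a hypothetical helper, not an RL4LMs function.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def cap_samples(samples: List[T], max_size: Optional[int]) -> List[T]:
    """Hypothetical helper: keep at most max_size samples (None = keep all)."""
    if max_size is None:
        return list(samples)
    if max_size < 0:
        raise ValueError("max_size must be non-negative")
    # Keep the first max_size items in their original order
    return list(samples[:max_size])


split = [f"article {i}" for i in range(1000)]
print(len(cap_samples(split, 500)))
```

Whether RL4LMs takes the first samples or a random subset, and whether the cap applies to every split, would need to be checked in the datapool source.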