vcvcvnvcvcvn opened 1 year ago
My YAML config:

```yaml
tokenizer:
  model_name: facebook/bart-large-cnn
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500

env:
  n_envs: 1
  args:
    max_prompt_length: 64
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args:
    n_steps: 512
    batch_size: 4
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: facebook/bart-large-cnn
      apply_model_parallel: False
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100

train_evaluation:
  eval_batch_size: 16
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    # - id: bleurt
    #   args:
    #     config_name: bleurt-large-512
    - id: diversity
      args: {}
    # - id: summaCZS
    #   args:
    #     granularity: sentence
    #     use_ent: True
    #     use_con: False
    # - id: summaCConv
    #   args:
    #     granularity: sentence
  generation_kwargs:
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100
```
My error:
```
/content/RL4LMs/scripts/training/train_text_generation.py in main(config_path, project_name, experiment_name, base_path_to_store_results, entity_name, log_to_wandb)
     53         tracker=tracker,
     54     )
---> 55     trainer.train_and_eval()
     56 
     57 

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in train_and_eval(self)
    195         # evaluate on val and test set before fine-tuning once
    196         iter_start = self._trainer_state["current_iter"]
--> 197         self._evaluate_on_datapools(epoch=iter_start)
    198 
    199         # train for given number of iters

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in _evaluate_on_datapools(self, epoch, splits)
    181                               splits: List[str] = ["val", "test"]):
    182         for split in splits:
--> 183             evaluate_on_samples(policy=self._alg.policy,
    184                                 tokenizer=self._tokenizer,
    185                                 samples=self._samples_by_split[split],

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in evaluate_on_samples(policy, tokenizer, samples, batch_size, max_prompt_length, metrics, epoch, split_name, tracker, dt_control_token, gen_kwargs)
     39     n_samples = len(samples)
     40     for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"):
---> 41         batch_generated_texts = generate_text(
     42             policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs
     43         )

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in generate_text(policy, tokenizer, samples, max_prompt_length, dt_control_token, gen_kwargs)
    109         dt_control_token + sample.prompt_or_input_text for sample in samples
    110     ]
--> 111     generated_texts = policy.generate(
    112         tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs
    113     ).gen_texts

/content/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py in generate(self, tokenizer, texts, max_prompt_length, input_ids, attention_mask, gen_kwargs)
    254             actions_at_step = gen_tokens[:, step]
    255             distribution = Categorical(logits=raw_logits)
--> 256             log_probs = distribution.log_prob(actions_at_step)
    257             step_wise_logprobs.append(log_probs)
    258             step_wise_actions.append(actions_at_step)

/usr/local/lib/python3.8/dist-packages/torch/distributions/categorical.py in log_prob(self, value)
    121     def log_prob(self, value):
    122         if self._validate_args:
--> 123             self._validate_sample(value)
    124         value = value.long().unsqueeze(-1)
    125         value, log_pmf = torch.broadcast_tensors(value, self.logits)

/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    280         for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
    281             if i != 1 and j != 1 and i != j:
--> 282                 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
    283                                  format(actual_shape, expected_shape))
    284         try:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([400]).
```
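The last frame shows the check that fails: PyTorch compares the shape of `actions_at_step` (`torch.Size([100])`) with the distribution's `batch_shape + event_shape` (`torch.Size([400])`), walking the dimensions from the right. The sketch below re-implements only that loop in plain Python, as an illustration and not PyTorch's actual code path; `is_broadcastable` is a hypothetical helper name and the two shapes are taken from the error message.

```python
from typing import Sequence


def is_broadcastable(actual_shape: Sequence[int], expected_shape: Sequence[int]) -> bool:
    """Hypothetical mirror of the trailing-dimension loop in _validate_sample."""
    for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
        # Two sizes are compatible if they are equal or either one is 1.
        if i != 1 and j != 1 and i != j:
            return False
    return True


# Categorical(logits=raw_logits) with raw_logits of shape (400, vocab_size)
# has batch_shape (400,) and an empty event_shape.
expected = (400,)
# actions_at_step = gen_tokens[:, step] holds one token per generated sequence.
actual = (100,)

print(is_broadcastable(actual, expected))   # False, hence the ValueError
print(is_broadcastable((400,), expected))   # True
print(expected[0] // actual[0])             # 4 logit rows per generated sequence
```

The 400:100 ratio means the logits carry four rows per generated sequence. One possible reading, not confirmed in this issue, is beam expansion: the `facebook/bart-large-cnn` checkpoint's default generation config sets `num_beams: 4`, and neither `generation_kwargs` block above overrides it, so explicitly setting `num_beams: 1` may be worth trying.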
By the way, can I use

```yaml
datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500
```

to control how much data is used?
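As a sketch of what a size cap would mean for evaluation, assuming (not confirmed here) that the datapool applies `max_size` as a per-split limit on the number of samples, the snippet below caps a placeholder split and counts the evaluation batches for `eval_batch_size: 16`. `cap_samples` and the placeholder articles are hypothetical and are not the RL4LMs implementation.

```python
from itertools import islice
from math import ceil
from typing import Iterable, List, Optional


def cap_samples(samples: Iterable[str], max_size: Optional[int]) -> List[str]:
    """Hypothetical helper: keep at most max_size samples (None keeps all)."""
    if max_size is None:
        return list(samples)
    return list(islice(samples, max_size))


split = [f"article {i}" for i in range(2000)]  # placeholder split, not real data
capped = cap_samples(split, 500)               # max_size: 500
n_batches = ceil(len(capped) / 16)             # eval_batch_size: 16
print(len(capped), n_batches)                  # 500 32
```

Under that assumption, 500 samples give 31 full batches of 16 plus a final batch of 4.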