mindspore-lab / mindrlhf

Apache License 2.0

inputs shape problem during make_experience for llama, pangu, baichuan #42

Open kfertakis opened 10 months ago

kfertakis commented 10 months ago

Hi,

Related to #41: if I try to work around that issue, either by changing seq_length in the model config or by manually slicing the tensor inside ppo_trainer ([1]) to match the input sizes, I get as far as make_experience, but then I hit the following error:

Traceback (most recent call last):
  File "~/mindrlhf/train.py", line 109, in <module>
    run_rlhf(args)
  File "~/mindrlhf/train.py", line 99, in run_rlhf
    trainer.make_experience(num_rollouts=ppo_config.num_rollouts)
  File "~/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 240, in make_experience
    samples, resposne_array, left_padding_prompt = self.generate(prompt_tensors)
  File "~/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 190, in generate
    outputs = self.ppo_model.policy_model.model.generate(input_ids_list, max_length=self.ppo_config.seq_length)
  File "~/venv/lib/python3.9/site-packages/mindformers/generation/text_generator.py", line 557, in generate
    output_ids = self._forward(
  File "~/venv/lib/python3.9/site-packages/mindformers/generation/text_generator.py", line 359, in _forward
    res = self(**model_inputs)  # pylint: disable=E1102
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 680, in __call__
    out = self.compile_and_run(*args, **kwargs)
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 1020, in compile_and_run
    self.compile(*args, **kwargs)
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 997, in compile
    _cell_graph_executor.compile(self, phase=self.phase,
  File "~/venv/lib/python3.9/site-packages/mindspore/common/api.py", line 1547, in compile
    result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
ValueError: For BatchMatMul, inputs shape cannot be broadcast on CPU/GPU, with x shape [const vector]{1, 32, 11000, 128}, y shape [const vector]{128, 128}
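
For reference, here is a minimal sketch of how the two shapes from the error would combine under NumPy-style matmul broadcasting. NumPy and the helper name `matmul_result_shape` are assumptions used only for illustration; they are not part of mindrlhf, mindformers, or MindSpore, and the sketch does not reproduce MindSpore's own BatchMatMul shape check.

```python
import numpy as np

# Shapes copied from the error message above.
x_shape = (1, 32, 11000, 128)
y_shape = (128, 128)


def matmul_result_shape(x_shape, y_shape):
    """Hypothetical helper: result shape under NumPy-style matmul rules.

    The last two dims are treated as matrices; all leading dims are
    batch dims and are broadcast against each other.
    """
    if x_shape[-1] != y_shape[-2]:
        raise ValueError(
            f"inner dims differ: {x_shape[-1]} vs {y_shape[-2]}"
        )
    # y has no batch dims here, so this broadcasts (1, 32) against ().
    batch = np.broadcast_shapes(x_shape[:-2], y_shape[:-2])
    return (*batch, x_shape[-2], y_shape[-1])


print(matmul_result_shape(x_shape, y_shape))
```

Under these rules the inner dimensions agree (128 and 128) and the rank-2 operand simply has no batch dimensions, so the shapes would be accepted. That suggests, as an assumption rather than a confirmed diagnosis, that the failure comes from the rank mismatch between the 4-D and 2-D operands on the CPU/GPU backend and not from an inner-dimension mismatch.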

Can you please advise? Thanks.