Hi,

Related to #41: if I try to work around the issue by changing seq_length in the model config, or by manually slicing the tensor inside ppo_trainer ([1]) to match the input sizes, I get as far as make_experience, but then I hit the following error:
Traceback (most recent call last):
  File "~/mindrlhf/train.py", line 109, in <module>
    run_rlhf(args)
  File "~/mindrlhf/train.py", line 99, in run_rlhf
    trainer.make_experience(num_rollouts=ppo_config.num_rollouts)
  File "~/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 240, in make_experience
    samples, resposne_array, left_padding_prompt = self.generate(prompt_tensors)
  File "~/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 190, in generate
    outputs = self.ppo_model.policy_model.model.generate(input_ids_list, max_length=self.ppo_config.seq_length)
  File "~/venv/lib/python3.9/site-packages/mindformers/generation/text_generator.py", line 557, in generate
    output_ids = self._forward(
  File "~/venv/lib/python3.9/site-packages/mindformers/generation/text_generator.py", line 359, in _forward
    res = self(**model_inputs) # pylint: disable=E1102
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 680, in __call__
    out = self.compile_and_run(*args, **kwargs)
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 1020, in compile_and_run
    self.compile(*args, **kwargs)
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 997, in compile
    _cell_graph_executor.compile(self, phase=self.phase,
  File "~/venv/lib/python3.9/site-packages/mindspore/common/api.py", line 1547, in compile
    result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
ValueError: For BatchMatMul, inputs shape cannot be broadcast on CPU/GPU, with x shape [const vector]{1, 32, 11000, 128}, y shape [const vector]{128, 128}
Can you please advise? Thanks.