huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl

RuntimeError: chunk expects at least a 1-dimensional tensor #2338

Open imrankh46 opened 3 hours ago

imrankh46 commented 3 hours ago

System Info

- trl: 0.13.0.dev0
- transformers: 4.46.2
- Python: 3.11.10

Reproduction

SFTTrainer raises RuntimeError: chunk expects at least a 1-dimensional tensor with the following setup.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer

model_id = 'Qwen/Qwen2.5-32B-Instruct'

# bnb_config is not shown in the original snippet; a 4-bit NF4 config is assumed here.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# LoRA config
peft_config = LoraConfig(
    # init_lora_weights="pissa",
    r=32,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj'],
    target_modules="all-linear",
    modules_to_save=['lm_head', 'embed_tokens'],  # Qwen2's embedding module is named embed_tokens
)
model = prepare_model_for_kbit_training(model)

def formatting_prompts_func(example):
    # Extracting data from the example, checking for None values
    instruction = example['instruction']
    input_data = example['input_original_text'] if example['input_original_text'] is not None else ""
    output_data = example['output_translation_correct_text'] if example['output_translation_correct_text'] is not None else ""

    # Constructing the messages
    messages = [
        {'role': 'system', 'content': instruction},
        {'role': 'user', 'content': input_data},
        {'role': 'assistant', 'content': output_data}
    ]

    # Applying the chat template using the provided tokenizer
    texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    # # Handle NoneType result by assigning an empty string if texts is None
    # if texts is None:
    #     texts = ""

    return {"text": texts}

# Example usage on the Hugging Face dataset
formatted_e = formatting_prompts_func(train_set[5])
print(formatted_e['text'])

# Apply the function to the entire Hugging Face dataset
train_sft = train_set.map(formatting_prompts_func)
test_sft = test_set.map(formatting_prompts_func)

HAS_BFLOAT16 = torch.cuda.is_bf16_supported()
orpo_args = SFTConfig(
    # use_liger=True,
    max_seq_length = 1024,
    learning_rate=3e-4,
    weight_decay=0.01,
    warmup_steps=10,
    seed=0,
    lr_scheduler_type="linear",
    #max_steps = 1000,
    max_steps=-1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    fp16 = not HAS_BFLOAT16,
    bf16 = HAS_BFLOAT16,
    optim="adamw_8bit",
    # optim="galore_adamw",
    # optim_target_modules=optim_target_modules,
    num_train_epochs=5,
    eval_strategy="steps",
    # remove_unused_columns=False,
    # load_best_model_at_end = True,
    dataset_text_field="text",
    eval_steps=0.5,
    save_strategy = 'steps',
    logging_steps=1,
    #logging_steps=10,
    resume_from_checkpoint = True,
    report_to="wandb",
    output_dir="SFT-Qwen2.5-32b-lora-1",
    # packing=True
    # push_to_hub=True
)

import os
trainer = SFTTrainer(
    model=model,
    args=orpo_args,
    train_dataset=train_sft,
    eval_dataset=test_sft,
    #data_collator=collator,
    peft_config=peft_config,
    # dataset_text_field="text",
    # tokenizer=tokenizer,
    processing_class=tokenizer,
    # dataset_num_proc = os.cpu_count()
    # dataset_num_proc=6,

)
trainer.train()

Error traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[23], line 1
----> 1 trainer.train()

File /usr/local/lib/python3.11/dist-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2121         hf_hub_utils.enable_progress_bars()
   2122 else:
-> 2123     return inner_training_loop(
   2124         args=args,
   2125         resume_from_checkpoint=resume_from_checkpoint,
   2126         trial=trial,
   2127         ignore_keys_for_eval=ignore_keys_for_eval,
   2128     )

File /usr/local/lib/python3.11/dist-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2475 context = (
   2476     functools.partial(self.accelerator.no_sync, model=model)
   2477     if i == len(batch_samples) - 1
   2478     else contextlib.nullcontext
   2479 )
   2480 with context():
-> 2481     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2483 if (
   2484     args.logging_nan_inf_filter
   2485     and not is_torch_xla_available()
   2486     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2487 ):
   2488     # if loss is nan or inf simply add the average of previous logged losses
   2489     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /usr/local/lib/python3.11/dist-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3576     return loss_mb.reduce_mean().detach().to(self.args.device)
   3578 with self.compute_loss_context_manager():
-> 3579     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3581 del inputs
   3582 if (
   3583     self.args.torch_empty_cache_steps is not None
   3584     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3585 ):

File /usr/local/lib/python3.11/dist-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3631         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3632     inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
   3634 # Save past state if it exists
   3635 # TODO: this needs to be fixed and made cleaner later.
   3636 if self.args.past_index >= 0:

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/data_parallel.py:176, in DataParallel.forward(self, *inputs, **kwargs)
    171     if t.device != self.src_device_obj:
    172         raise RuntimeError("module must have its parameters and buffers "
    173                            f"on device {self.src_device_obj} (device_ids[0]) but found one of "
    174                            f"them on device: {t.device}")
--> 176 inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
    177 # for forward function without any inputs, empty list and dict will be created
    178 # so the module can be executed on one device which is the first one in device_ids
    179 if not inputs and not module_kwargs:

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/data_parallel.py:198, in DataParallel.scatter(self, inputs, kwargs, device_ids)
    192 def scatter(
    193     self,
    194     inputs: Tuple[Any, ...],
    195     kwargs: Optional[Dict[str, Any]],
    196     device_ids: Sequence[Union[int, torch.device]],
    197 ) -> Any:
--> 198     return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/scatter_gather.py:78, in scatter_kwargs(inputs, kwargs, target_gpus, dim)
     76 r"""Scatter with support for kwargs dictionary."""
     77 scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
---> 78 scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     79 if len(scattered_inputs) < len(scattered_kwargs):
     80     scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs)))

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/scatter_gather.py:64, in scatter(inputs, target_gpus, dim)
     58 # After scatter_map is called, a scatter_map cell will exist. This cell
     59 # has a reference to the actual function scatter_map, which has references
     60 # to a closure that has a reference to the scatter_map cell (because the
     61 # fn is recursive). To avoid this reference cycle, we set the function to
     62 # None, clearing the cell
     63 try:
---> 64     res = scatter_map(inputs)
     65 finally:
     66     scatter_map = None  # type: ignore[assignment]

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/scatter_gather.py:55, in scatter.<locals>.scatter_map(obj)
     53     return [list(i) for i in zip(*map(scatter_map, obj))]
     54 if isinstance(obj, dict) and len(obj) > 0:
---> 55     return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
     56 return [obj for _ in target_gpus]

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/scatter_gather.py:51, in scatter.<locals>.scatter_map(obj)
     49     return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
     50 if isinstance(obj, tuple) and len(obj) > 0:
---> 51     return list(zip(*map(scatter_map, obj)))
     52 if isinstance(obj, list) and len(obj) > 0:
     53     return [list(i) for i in zip(*map(scatter_map, obj))]

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/scatter_gather.py:47, in scatter.<locals>.scatter_map(obj)
     45 def scatter_map(obj):
     46     if isinstance(obj, torch.Tensor):
---> 47         return Scatter.apply(target_gpus, None, dim, obj)
     48     if _is_namedtuple(obj):
     49         return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]

File /usr/local/lib/python3.11/dist-packages/torch/autograd/function.py:574, in Function.apply(cls, *args, **kwargs)
    571 if not torch._C._are_functorch_transforms_active():
    572     # See NOTE: [functorch vjp and autograd interaction]
    573     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 574     return super().apply(*args, **kwargs)  # type: ignore[misc]
    576 if not is_setup_ctx_defined:
    577     raise RuntimeError(
    578         "In order to use an autograd.Function with functorch transforms "
    579         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    580         "staticmethod. For more details, please see "
    581         "https://pytorch.org/docs/main/notes/extending.func.html"
    582     )

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/_functions.py:96, in Scatter.forward(ctx, target_gpus, chunk_sizes, dim, input)
     93 if torch.cuda.is_available() and ctx.input_device == -1:
     94     # Perform CPU to GPU copies in a background stream
     95     streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
---> 96 outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
     97 # Synchronize with the copy stream
     98 if streams is not None:

File /usr/local/lib/python3.11/dist-packages/torch/nn/parallel/comm.py:188, in scatter(tensor, devices, chunk_sizes, dim, streams, out)
    186 if out is None:
    187     devices = [_get_device_index(d) for d in devices]
--> 188     return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
    189 else:
    190     if devices is not None:

RuntimeError: chunk expects at least a 1-dimensional tensor
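
For context, the traceback bottoms out in torch's comm.scatter, which chunks every tensor input/kwarg across the visible GPUs; chunk requires at least one dimension, so a zero-dimensional (scalar) tensor triggers exactly this message. A minimal, TRL-independent sketch of the low-level failure (the scalar value here is purely illustrative):

import torch

# A 0-dimensional (scalar) tensor, e.g. what a scalar kwarg such as
# num_items_in_batch may look like when DataParallel scatters the model inputs.
scalar = torch.tensor(8)

try:
    # chunk() needs at least one dimension, so this raises the same
    # "chunk expects at least a 1-dimensional tensor" RuntimeError.
    scalar.chunk(2)
except RuntimeError as e:
    print(e)

# A 1-dimensional tensor chunks without error.
print(torch.tensor([8]).chunk(2))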

Expected behavior

Training should run without raising RuntimeError: chunk expects at least a 1-dimensional tensor.

imrankh46 commented 3 hours ago

@kashif any suggestions?