huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Tutorial for using DeepSpeed's activation checkpointing instead of PyTorch's #32409

Open huyiwen opened 1 month ago

huyiwen commented 1 month ago

Feature request

Is there a tutorial for using DeepSpeed's activation checkpointing instead of PyTorch's?

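I'm using Trainer with the DeepSpeed ZeRO integration to train my model. Here is the code that installs DeepSpeed's checkpoint function, followed by the activation_checkpointing section of my DeepSpeed config and the launch command: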

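if training_args.deepspeed_gradient_checkpointing and training_args.deepspeed:
    # deepspeed_gradient_checkpointing is a custom flag from this training script,
    # not a standard TrainingArguments field.
    from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, configure

    # Set up DeepSpeed's activation-checkpointing module without a model-parallel unit.
    configure(mpu_=None)
    # Install DeepSpeed's checkpoint() as the model's gradient-checkpointing function.
    model._set_gradient_checkpointing(training_args.deepspeed_gradient_checkpointing, checkpoint)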
{
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true,
    "contiguous_memory_optimization": false,
    "number_checkpoints": null,
    "synchronize_checkpoint_boundary": false,
    "profile": false
  }
}
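One way to check whether the failure depends on these options is to generate variants of the config with individual activation_checkpointing flags switched off and rerun the same command with each. This is a minimal sketch, assuming the config is a JSON file on disk; activation_checkpointing_variant and write_variant are hypothetical helper names, and only the six keys shown above are accepted as overrides.

```python
import copy
import json
from pathlib import Path

# The activation_checkpointing keys used in the config above.
KNOWN_KEYS = {
    "partition_activations",
    "cpu_checkpointing",
    "contiguous_memory_optimization",
    "number_checkpoints",
    "synchronize_checkpoint_boundary",
    "profile",
}


def activation_checkpointing_variant(config: dict, **overrides) -> dict:
    """Return a deep copy of a DeepSpeed config dict with selected
    activation_checkpointing options overridden (hypothetical helper)."""
    unknown = set(overrides) - KNOWN_KEYS
    if unknown:
        raise KeyError(f"unknown activation_checkpointing keys: {sorted(unknown)}")
    variant = copy.deepcopy(config)  # leave the caller's dict untouched
    variant.setdefault("activation_checkpointing", {}).update(overrides)
    return variant


def write_variant(src: Path, dst: Path, **overrides) -> None:
    """Read the JSON config at src, apply the overrides, and write the result to dst."""
    config = json.loads(src.read_text())
    variant = activation_checkpointing_variant(config, **overrides)
    dst.write_text(json.dumps(variant, indent=2))
```

For example, write_variant(Path("ds_config.json"), Path("ds_no_partition.json"), partition_activations=False) writes a copy that can be passed to --deepspeed in place of ${DEEPSPEED_CONFIG_PATH}; the file names here are placeholders. Whether these JSON flags reach the checkpointing module at all when configure(mpu_=None) is called without further arguments is itself something such a comparison would reveal, not something this sketch assumes.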
torchrun --nproc_per_node=8 \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    train.py \
    --deepspeed ${DEEPSPEED_CONFIG_PATH} \
    --gradient_checkpointing False

However, the forward pass fails inside my FlashAttention2 attention class:

class XXXFlashAttention2(XXXAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()  # <-- raises the ValueError below

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
  File "modeling_xxx.py", line 518, in forward
    bsz, q_len, _ = hidden_states.size()
ValueError: not enough values to unpack (expected 3, got 2)
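To find out which positional argument arrives as a 2D tensor, one option is to wrap DeepSpeed's checkpoint function before passing it to _set_gradient_checkpointing in the setup above. The traceback shows DeepSpeed's checkpoint forwarding its inputs as CheckpointFunction.apply(function, all_outputs, *args), so this sketch assumes the same (function, *args) calling convention. traced_checkpoint and describe are hypothetical helpers; anything with a shape attribute (such as a torch.Tensor) is reported by its shape, everything else by its type name.

```python
import functools
from typing import Any, Callable


def describe(value: Any) -> str:
    """Summarize an argument: tensor-like objects by shape, everything else by type."""
    shape = getattr(value, "shape", None)
    if shape is not None:
        return f"{type(value).__name__}{tuple(shape)}"
    return type(value).__name__


def traced_checkpoint(checkpoint_fn: Callable, log: Callable[[str], None] = print) -> Callable:
    """Wrap a checkpoint function that takes (function, *args) so that it logs
    every positional argument before delegating (hypothetical debugging helper)."""

    @functools.wraps(checkpoint_fn)
    def wrapper(function, *args):
        for index, arg in enumerate(args):
            log(f"checkpoint arg {index}: {describe(arg)}")
        return checkpoint_fn(function, *args)

    return wrapper
```

Used as model._set_gradient_checkpointing(True, traced_checkpoint(checkpoint)), each checkpointed layer logs its inputs on every forward pass, so it is best kept to a single step. Index 0 is the first argument after the layer function at the _gradient_checkpointing_func call site (modeling_miniyulan.py, line 1057 in the traceback). If the hidden_states slot is already 2D there, the problem lies before the checkpoint call; if it is 3D, something between the checkpoint call and the attention layer is changing it.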

Motivation

There doesn't seem to be such a tutorial at the moment, either in the DeepSpeed tutorials or in the Hugging Face documentation.

Your contribution

I can share my results.

huyiwen commented 1 month ago
Traceback (most recent call last):
  File "train.py", line 395, in <module>
    train()
  File "train.py", line 389, in train
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File ".../site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File ".../site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File ".../site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
  File ".../site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File ".../site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".../site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File ".../site-packages/deepspeed/runtime/engine.py", line 1822, in forward
    loss = self.module(*inputs, **kwargs)
  File ".../site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "model/modeling_miniyulan.py", line 1255, in forward
    outputs = self.model(
  File ".../site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "model/modeling_miniyulan.py", line 1057, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File ".../site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 985, in checkpoint
    CheckpointFunction.apply(function, all_outputs, *args)
  File ".../site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File ".../site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 562, in forward
    outputs = run_function(*inputs_cuda)
  File ".../site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "model/modeling_miniyulan.py", line 801, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File ".../site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "model/modeling_miniyulan.py", line 517, in forward
    bsz, q_len, _ = hidden_states.size()
ValueError: not enough values to unpack (expected 3, got 2)

ArthurZucker commented 3 weeks ago

cc @muellerzr and @SunMarc