microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Mixed-precision: fp16 will cast input_ids into torch.cuda.HalfTensor instead of Long or Int. #5701

Closed zhaoyang02 closed 3 months ago

zhaoyang02 commented 4 months ago

Yes, DeepSpeed fp16/ZeRO and other mixed-precision training schemes perform forward and backward passes in fp16. This is the source of the memory savings of mixed-precision training. I am not sure why fp32 inputs worked in older DeepSpeed versions; that was probably due to a bug. I can look into this later if you want and I get some bandwidth :).

Also, I just noticed that stage is set to 0 in the "zero_optimization" dictionary of your JSON config. I hope you are aware that this means ZeRO is disabled, while fp16 is still enabled.

What will happen if the input is a LongTensor, such as the input_ids of a transformer? I found that DeepSpeed casts input_ids to a half tensor, which crashes the model at the embedding layer (since it requires Long input).

Originally posted by @NickyMouseSG in https://github.com/microsoft/DeepSpeed/issues/550#issuecomment-1722239501
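For contrast with the behavior described in the question above, a dtype-aware cast would convert only floating-point tensors and leave integer tensors such as input_ids untouched. The helper below is a hypothetical sketch (`cast_floating_to_half` is our own name, not a DeepSpeed function, and the `features` key is made up for the demo); it assumes PyTorch is installed:

```python
import torch


def cast_floating_to_half(inputs):
    """Hypothetical helper: recursively cast floating-point tensors to fp16,
    leaving integer tensors (such as input_ids) and non-tensors unchanged."""
    if isinstance(inputs, torch.Tensor):
        return inputs.half() if inputs.is_floating_point() else inputs
    if isinstance(inputs, (list, tuple)):
        converted = [cast_floating_to_half(v) for v in inputs]
        return tuple(converted) if isinstance(inputs, tuple) else converted
    if isinstance(inputs, dict):
        return {k: cast_floating_to_half(v) for k, v in inputs.items()}
    return inputs


batch = {
    "input_ids": torch.tensor([[101, 2054, 2003]]),   # int64 token IDs
    "attention_mask": torch.ones(1, 3, dtype=torch.long),
    "features": torch.randn(1, 3, 4),                  # hypothetical float input
}
casted = cast_floating_to_half(batch)
for name, tensor in casted.items():
    print(name, tensor.dtype)
# input_ids torch.int64
# attention_mask torch.int64
# features torch.float16
```

Only the float tensor changes dtype; the token IDs keep their values and stay valid indices for an embedding layer.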

My task is to test DPO training of the llama3-8b model on the Bridges-2 platform with 16 V100-32GB GPUs. V100s don't support bf16, so I set fp16: true to use mixed precision. The code is based on alignment-handbook.

With DeepSpeed 0.12.2 and mixed_precision: fp16, the model's input_ids are converted to torch.float16, even though they should be Int or Long. The DeepSpeed ZeRO-3 config file (Accelerate format):

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The traceback:

[v007:1]:Traceback (most recent call last):
[v007:1]:  File "/jet/home/yzhao15/yangzhao/handbook/scripts/run_dpo.py", line 261, in <module>
[v007:1]:    main()
[v007:1]:  File "/jet/home/yzhao15/yangzhao/handbook/scripts/run_dpo.py", line 214, in main
[v007:1]:    train_result = trainer.train(resume_from_checkpoint=checkpoint)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1932, in train
[v007:1]:    return inner_training_loop(
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
[v007:1]:    tr_loss_step = self.training_step(model, inputs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 3307, in training_step
[v007:1]:    loss = self.compute_loss(model, inputs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1082, in compute_loss
[v007:1]:    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1023, in get_batch_loss_metrics
[v007:1]:    ) = self.concatenated_forward(model, batch)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 986, in concatenated_forward
[v007:1]:    all_logits = model(
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[v007:1]:    return self._call_impl(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[v007:1]:    return forward_call(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[v007:1]:    ret_val = func(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1814, in forward
[v007:1]:    loss = self.module(*inputs, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[v007:1]:    return self._call_impl(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
[v007:1]:    result = forward_call(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1423, in forward
[v007:1]:    transformer_outputs = self.transformer(
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[v007:1]:    return self._call_impl(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
[v007:1]:    result = forward_call(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1135, in forward
[v007:1]:    inputs_embeds = self.wte(input_ids)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[v007:1]:    return self._call_impl(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
[v007:1]:    result = forward_call(*args, **kwargs)
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
[v007:1]:    return F.embedding(
[v007:1]:  File "/jet/home/yzhao15/.conda/envs/handbook/lib/python3.10/site-packages/torch/nn/functional.py", line 2237, in embedding
[v007:1]:    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[v007:1]:RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.HalfTensor instead (while checking arguments for embedding)
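The final error in the traceback can be reproduced without DeepSpeed: torch.nn.Embedding only accepts Long or Int indices, so the same token IDs stored as a half tensor are rejected. A minimal sketch, assuming a CPU-only PyTorch install (on CPU the message names torch.HalfTensor rather than torch.cuda.HalfTensor, and exact wording may vary across versions):

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
input_ids = torch.tensor([[1, 2, 3]])  # int64, as produced by a tokenizer

# Integer indices work as expected.
output = embedding(input_ids)
print(output.shape)  # torch.Size([1, 3, 4])

# The same IDs after an fp16 cast are rejected by the embedding lookup.
error_message = None
try:
    embedding(input_ids.half())
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```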

I've added three print statements just before input_ids reaches the model. In trl's dpo_trainer.py:

print(f"concatenated_input_ids TYPE:{concatenated_batch['concatenated_input_ids'].dtype}")
all_logits = model(
    concatenated_batch["concatenated_input_ids"],
    attention_mask=concatenated_batch["concatenated_attention_mask"],
    use_cache=False,
    **model_kwargs,
).logits

In transformers' modeling_llama.py:

# In LlamaModel.forward:
print(f"LlamaModel input_ids TYPE:{input_ids.dtype}")

if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)

# In LlamaForCausalLM.forward:
print(f"LlamaForCausalLM input_ids TYPE:{input_ids.dtype}")
outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
)

The output was:

[v022:0]:concatenated_input_ids TYPE: torch.int64
[v022:0]:LlamaForCausalLM input_ids TYPE: torch.float16
[v022:0]:LlamaModel input_ids TYPE: torch.float16
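The same dtype check could be done from the training script, without editing the installed trl and transformers sources, by using a PyTorch forward pre-hook. This is a sketch under assumptions: `with_kwargs=True` requires PyTorch 2.0 or newer, `TinyLM` is a hypothetical stand-in for the real model, and under DeepSpeed the hook would go on the engine's wrapped model (`module`, which the traceback shows the engine calling as `self.module(*inputs, **kwargs)`), so it sees the inputs after any engine-side casting:

```python
import torch
import torch.nn as nn

seen = []  # (module name, argument name, dtype) triples recorded by the hook


def log_input_dtypes(module, args, kwargs):
    """Forward pre-hook that records and prints the dtype of every tensor argument."""
    name = type(module).__name__
    named = [(f"arg{i}", a) for i, a in enumerate(args)] + list(kwargs.items())
    for arg_name, value in named:
        if isinstance(value, torch.Tensor):
            seen.append((name, arg_name, value.dtype))
            print(f"{name} {arg_name} TYPE: {value.dtype}")
    # Returning None leaves args and kwargs unchanged.


class TinyLM(nn.Module):
    """Hypothetical stand-in for the real causal LM."""

    def __init__(self, vocab_size=16, hidden_size=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids, attention_mask=None):
        return self.embed_tokens(input_ids)


model = TinyLM()
handle = model.register_forward_pre_hook(log_input_dtypes, with_kwargs=True)
model(torch.tensor([[1, 2, 3]]), attention_mask=torch.ones(1, 3, dtype=torch.long))
handle.remove()  # detach the hook once debugging is done
```

Removing the handle afterwards keeps the logging from running on every training step.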

When I add input_ids = input_ids.long() before the model receives it, the dtype error goes away, but I get a different error instead:

../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [16,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
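One plausible explanation for this second error (our assumption; nobody in the thread confirms it) is that calling .long() after the fp16 cast cannot recover the original token IDs. IEEE half precision represents every integer exactly only up to 2048, rounds larger values to a coarser grid, and overflows to infinity from 65520 upward (the largest finite fp16 value is 65504). With a vocabulary of about 128K entries, as in Llama 3, some IDs can therefore become wrong or out-of-range indices, which is what the `srcIndex < srcSelectDimSize` assertion guards against. The sketch below uses Python's struct "e" format (IEEE 754 binary16, the storage format of torch.float16) to show the round trip without PyTorch:

```python
import math
import struct


def fp16_round_trip(token_id: int) -> float:
    """Store an integer token ID as IEEE 754 half precision and read it back."""
    try:
        # "<e" is binary16; packing rounds to the nearest value, ties to even.
        return struct.unpack("<e", struct.pack("<e", float(token_id)))[0]
    except OverflowError:
        # struct raises on overflow where an IEEE cast (as in PyTorch)
        # produces infinity, so report infinity here as well.
        return math.inf


for token_id in [5, 2048, 2049, 50000, 128000]:
    print(token_id, "->", fp16_round_trip(token_id))
# 5 -> 5.0
# 2048 -> 2048.0
# 2049 -> 2048.0
# 50000 -> 49984.0
# 128000 -> inf
```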

I've tested several versions of trl and transformers and got the same issues, so I think the bug is caused by DeepSpeed.

Thanks!

lintao-common commented 4 months ago

There is a solution here, and the problem has been resolved. https://discuss.huggingface.co/t/getting-torch-cuda-halftensor-error-while-using-deepspeed-with-accelerate/39997/6
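The fix from the linked thread is not reproduced here. As an assumption, if the cause is DeepSpeed's fp16 `auto_cast` option (documented as automatically casting inputs to fp16), one way to rule it out is to give Accelerate an explicit DeepSpeed JSON config with that option turned off. The sketch below writes such a file; the file name `ds_zero3_fp16.json` is hypothetical, the ZeRO fields mirror the Accelerate config above, and the `"auto"` batch-size values assume the Hugging Face integration fills them in:

```python
import json

# Hypothetical DeepSpeed config mirroring the Accelerate ZeRO-3 settings above.
ds_config = {
    "fp16": {
        "enabled": True,
        "auto_cast": False,  # assumption: keeps integer inputs such as input_ids as Long
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "none"},
        "offload_param": {"device": "none"},
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    # Assumed to be resolved by the Hugging Face integration at launch time.
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

with open("ds_zero3_fp16.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```

The Accelerate config would then reference this file through deepspeed_config_file under deepspeed_config, in place of the individual zero_stage and offload keys.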

tjruwase commented 3 months ago


@lintao-common, thanks for sharing solution. Closing this issue.