EleutherAI / gpt-neox

An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries
https://www.eleuther.ai/
Apache License 2.0

Error with rotary embeddings and BFloat16 #1305

Closed (jahatef closed this issue 6 days ago)

jahatef commented 1 month ago

Describe the bug

When using rotary positional embeddings and BF16, a type mismatch is produced:

Traceback (most recent call last):
  File "/home/hatef.4/neox/gpt-neox/train.py", line 35, in <module>
    main()
  File "/home/hatef.4/neox/gpt-neox/train.py", line 31, in main
    pretrain(neox_args=neox_args)
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 296, in pretrain
    iteration = train(
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 1465, in train
    loss_dict, skipped_iter = train_step(
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 1277, in train_step
    reduced_loss = train_step_pipe(
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 1374, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/engine.py", line 362, in train_batch
    self._exec_schedule(sched)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/engine.py", line 1345, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/engine.py", line 657, in _exec_forward_pass
    outputs = super().forward(inputs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/engine.py", line 1836, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 365, in forward
    x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 989, in checkpoint
    CheckpointFunction.apply(function, all_outputs, *args)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 566, in forward
    outputs = run_function(*inputs_cuda)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 344, in exec_func
    inputs = layer(inputs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 1302, in forward
    output, moe_loss = super().forward(hidden_states, attention_mask)
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 1228, in forward
    attention_output, attention_bias = self.attention(
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 952, in forward
    context_layer = self.attention(
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 617, in attention
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
RuntimeError: expected scalar type BFloat16 but found Float
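
The final frame fails because torch.bmm does not promote dtypes: both batched operands must share one. The following is a minimal sketch, assuming a recent PyTorch build on CPU and arbitrary small tensor sizes (not NeoX's real shapes). It reproduces the same class of error outside NeoX and shows that the call succeeds once both operands share a dtype. The exact RuntimeError text varies between PyTorch versions. The cast only illustrates the mismatch; it is not the fix adopted in this issue.

```python
import torch

# Shapes mirror the failing call: attention_probs is [b * np, sq, sk] and
# value_layer is [sk, b * np, hn] before the transpose. Sizes are arbitrary.
batch_heads, seq_len, head_dim = 4, 3, 8

# Assumed reproduction: probabilities in float32, values in bfloat16.
attention_probs = torch.softmax(torch.randn(batch_heads, seq_len, seq_len), dim=-1)
value_layer = torch.randn(seq_len, batch_heads, head_dim).to(torch.bfloat16)


def try_bmm(probs, values):
    """Return the context tensor, or None if torch.bmm rejects the dtypes."""
    try:
        return torch.bmm(probs, values.transpose(0, 1))
    except RuntimeError as err:
        print(f"RuntimeError: {err}")
        return None


# float32 x bfloat16: torch.bmm raises instead of promoting.
mismatched = try_bmm(attention_probs, value_layer)

# Casting the probabilities to the value dtype lets the batched matmul run.
matched = try_bmm(attention_probs.to(value_layer.dtype), value_layer)
print(matched.dtype, tuple(matched.shape))
```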

Here attention_probs has dtype Float (float32) instead of BFloat16.

To Reproduce

Steps to reproduce the behavior:

Take 1-3B.yml, delete the "fp16" config argument, and add:

  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

I have tracked this down to rotary embeddings, but I do not know whether the issue affects all positional embeddings. The error does not occur with pos_emb: "none".

AI-WAIFU commented 1 month ago

Turns out this was just a config issue. You need to specify "precision": "bfloat16" when using bf16; https://github.com/EleutherAI/gpt-neox/pull/1311 should ensure this won't confuse anyone else.
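
As a rough illustration of that resolution, here is a hypothetical pre-flight check over a NeoX-style config loaded as a Python dict. check_bf16_precision is not part of gpt-neox. The sketch assumes only the keys mentioned in this issue (a "bf16" block with "enabled", and a top-level "precision") and flags just the mismatch described here.

```python
def check_bf16_precision(config: dict) -> list:
    """Hypothetical helper (not part of gpt-neox): report the mismatch from this
    issue, where the bf16 block is enabled but "precision" is not "bfloat16"."""
    problems = []
    bf16_enabled = config.get("bf16", {}).get("enabled", False)
    if bf16_enabled and config.get("precision") != "bfloat16":
        problems.append('bf16 is enabled but "precision" is not "bfloat16"')
    return problems


# Assumed reproduction config: only the bf16 block from the reproduction steps
# is shown; every other key from 1-3B.yml is omitted for brevity.
reporter_config = {
    "bf16": {
        "bf16": True,
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 12,
        "hysteresis": 2,
        "min_loss_scale": 1,
    },
}

# The fix suggested in the issue: set "precision" alongside the bf16 block.
fixed_config = {**reporter_config, "precision": "bfloat16"}

print(check_bf16_precision(reporter_config))
print(check_bf16_precision(fixed_config))  # → []
```

A check like this would run before training starts, so the misconfiguration surfaces as a readable message rather than a dtype error deep inside the attention layer.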