Describe the bug
When using rotary positional embeddings and BF16, a type mismatch is produced:
Traceback (most recent call last):
  File "/home/hatef.4/neox/gpt-neox/train.py", line 35, in <module>
    main()
  File "/home/hatef.4/neox/gpt-neox/train.py", line 31, in main
    pretrain(neox_args=neox_args)
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 296, in pretrain
    iteration = train(
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 1465, in train
    loss_dict, skipped_iter = train_step(
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 1277, in train_step
    reduced_loss = train_step_pipe(
  File "/home/hatef.4/neox/gpt-neox/megatron/training.py", line 1374, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/engine.py", line 362, in train_batch
    self._exec_schedule(sched)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/engine.py", line 1345, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/engine.py", line 657, in _exec_forward_pass
    outputs = super().forward(inputs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/engine.py", line 1836, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 365, in forward
    x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 989, in checkpoint
    CheckpointFunction.apply(function, all_outputs, *args)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 566, in forward
    outputs = run_function(*inputs_cuda)
  File "/home/hatef.4/neox/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 344, in exec_func
    inputs = layer(inputs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 1302, in forward
    output, moe_loss = super().forward(hidden_states, attention_mask)
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 1228, in forward
    attention_output, attention_bias = self.attention(
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hatef.4/neox/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 952, in forward
    context_layer = self.attention(
  File "/home/hatef.4/neox/gpt-neox/megatron/model/transformer.py", line 617, in attention
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
RuntimeError: expected scalar type BFloat16 but found Float
attention_probs here is a Float type instead of BFloat16.
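For context, the failure is just the dtype check inside torch.bmm. The standalone snippet below is not the gpt-neox code path and the shapes are made up; it only shows that a float32 attention_probs multiplied with a bfloat16 value tensor raises the same class of error seen in the traceback:

    import torch

    # attention_probs left in float32 (as after the rotary/softmax path in this report),
    # value_layer in bfloat16 like the rest of the model's activations.
    attention_probs = torch.rand(2, 4, 4, dtype=torch.float32)
    value_layer = torch.rand(2, 4, 8, dtype=torch.bfloat16)

    try:
        torch.bmm(attention_probs, value_layer)
    except RuntimeError as err:
        # Exact wording varies by device and PyTorch version, but it is the same
        # "expected scalar type ... but found ..." dtype mismatch as above.
        print(err)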
To Reproduce
Steps to reproduce the behavior:
Take 1-3B.yml. Delete the "fp16" config arg and add:
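Something along these lines should reproduce it. This is a minimal sketch assuming the DeepSpeed-style bf16 block in the gpt-neox YAML, not the exact snippet from the original report (the key may be spelled "bf16" or "bfloat16" depending on the DeepSpeed version):

    "bf16": {
      "enabled": true
    }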
I have tracked this down to rotary embeddings, but I do not know if this is an issue with all positional embeddings. The issue is not produced when pos_emb: "none".
Turns out this was just a config issue: you need to specify "precision": "bfloat16" when using bf16. https://github.com/EleutherAI/gpt-neox/pull/1311 should ensure this won't confuse anyone else.
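For anyone hitting the same error, the working combination looks roughly like this. It is a sketch assuming the same 1-3B.yml layout; "precision" is the neox arg mentioned above, and "bf16" is the DeepSpeed-style block that replaces "fp16":

    "precision": "bfloat16",
    "bf16": {
      "enabled": true
    }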