sangmichaelxie / doremi

PyTorch implementation of DoReMi, a method for optimizing the data mixture weights in language modeling datasets.
https://arxiv.org/abs/2305.10429
MIT License
286 stars 32 forks

AssertionError: assert q.dtype in [torch.float16, torch.bfloat16] #31

Closed Richard-Wth closed 1 month ago

Richard-Wth commented 1 month ago

When I finished training the DoReMi reference model and tried to evaluate it on the downstream tasks, I got this error:

```
Traceback (most recent call last):
  File "/home/wth/My_codes/doremi/doremi/train.py", line 409, in <module>
    main()
  File "/home/wth/My_codes/doremi/doremi/train.py", line 395, in main
    downstream_metrics = trainer.evaluate_fewshot(
  File "/home/wth/My_codes/doremi/doremi/trainer.py", line 670, in evaluate_fewshot
    gen_tokens = model.generate(
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/utils/generation.py", line 166, in generate
    output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/utils/generation.py", line 115, in decode
    logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/My_codes/doremi/doremi/models.py", line 99, in forward
    fwd_output = self._forward(
  File "/home/wth/My_codes/doremi/doremi/models.py", line 59, in _forward
    hidden_states = self.transformer(input_ids, position_ids=position_ids,
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/models/gpt.py", line 373, in forward
    hidden_states, residual = layer(hidden_states, residual,
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/modules/block.py", line 148, in forward
    hidden_states = self.mixer(hidden_states, **mixer_kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 517, in forward
    context = self.inner_cross_attn(q, kv, causal=causal)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 124, in forward
    assert q.dtype in [torch.float16, torch.bfloat16]
AssertionError
```

May I ask how this problem can be solved?

sangmichaelxie commented 1 month ago

I haven't been able to reproduce this, but it means that either the model or the data you're feeding in has the wrong dtype and should be cast to bfloat16. FlashAttention's attention kernels only accept fp16/bf16 tensors, which is what the assertion is checking.
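A minimal sketch of that cast, using only standard PyTorch APIs (`model` here is a placeholder for the loaded reference model, not a name from this repo):

```python
import torch

# Cast all parameters and buffers to bfloat16 so the tensors reaching
# FlashAttention (q, kv) are in a dtype its kernels accept.
model = model.to(dtype=torch.bfloat16)

# Sanity-check before running evaluation; input_ids are integer tensors
# and don't need casting, but any floating-point inputs would.
bad = [n for n, p in model.named_parameters() if p.dtype != torch.bfloat16]
assert not bad, f"parameters not in bfloat16: {bad[:5]}"
```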

Richard-Wth commented 1 month ago

I fixed the bug by switching the CUDA version to 11.7.
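For anyone hitting the same error, one way to check which CUDA build PyTorch is actually running against (standard PyTorch APIs, nothing repo-specific):

```python
import torch

print(torch.__version__)          # installed PyTorch version
print(torch.version.cuda)         # CUDA version this PyTorch build was compiled against
print(torch.cuda.is_available())  # whether a usable GPU/driver was found
```

A mismatch between the CUDA toolkit used to build flash-attn and the one PyTorch was compiled against is a plausible cause of this kind of dtype/kernel failure.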