jzhang38 / TinyLlama

The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.
Apache License 2.0

Can anyone pretrain tinyllama.py on V100s? #99

Closed · JerryDaHeLian closed this issue 6 months ago

JerryDaHeLian commented 7 months ago

When I pretrain with pretrain/tinyllama.py I get an error. The stack trace (from one rank; the repeated torch.nn.Module `_wrapped_call_impl`/`_call_impl` wrapper frames and the interleaved tracebacks from the other ranks are omitted for readability):

```
Traceback (most recent call last):
  File "pretrain/tinyllama.py", line 533, in <module>
    CLI(setup)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "pretrain/tinyllama.py", line 138, in setup
    main(fabric, train_data_dir, val_data_dir, resume)
  File "pretrain/tinyllama.py", line 211, in main
    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
  File "pretrain/tinyllama.py", line 224, in train
    validate(fabric, model, val_dataloader)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "pretrain/tinyllama.py", line 410, in validate
    logits = model(input_ids)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 121, in forward
    output = self._forward_module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/xxx/TinyLlama/lit_gpt/model.py", line 107, in forward
    x, *_ = block(x, (cos, sin), max_seq_length)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/xxx/TinyLlama/lit_gpt/model.py", line 172, in forward
    h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)
  File "/home/xxx/TinyLlama/lit_gpt/model.py", line 236, in forward
    q = apply_rotary_emb_func(q, cos, sin, False, True)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 551, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/xxx/TinyLlama/lit_gpt/fused_rotary_embedding.py", line 39, in forward
    rotary_emb.apply_rotary(
RuntimeError: Expected x1.dtype() == cos.dtype() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```

My development environment has 8 V100S GPUs (Tesla V100S-PCIE-32GB). Prior to this, I was able to pretrain successfully on 8 A100-40G cards (NVIDIA A100-PCIE-40GB).

Who can help me? Thanks!

jzhang38 commented 7 months ago

flash_attn 2 only supports Ampere (and newer) GPUs, while the V100 is Volta. I believe that is one possible reason.
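A quick way to check whether a given card meets that requirement (a minimal sketch using standard PyTorch queries; the compute-capability threshold just encodes the Ampere requirement mentioned above):

```python
import torch

# Ampere and newer GPUs report CUDA compute capability >= 8.0;
# V100/V100S (Volta) report 7.0, A100 reports 8.0.
major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability: {major}.{minor}")
print(f"bf16 supported:     {torch.cuda.is_bf16_supported()}")

if major < 8:
    print("Pre-Ampere GPU: flash_attn 2 and its fused kernels "
          "(e.g. the fused rotary embedding in the traceback) are not supported.")
```

On an A100 this prints 8.0, which is consistent with the same script running fine on the A100 nodes.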