Caiyun-AI / DCFormer

MIT License

a bug in the HLO->LLVM IR lowering #5

Closed szrrr04 closed 3 months ago

szrrr04 commented 3 months ago

I tried to reproduce your training code using the JAX framework, and I switched the dataset to lm1b (so I modified the data-processing function in _tfds_data_processing.py, which is already present in your folder). Without modifying any data types, I encountered thousands of "Floating-point arithmetic operators only work with floating-point types!" errors during the run. It also printed: "This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR." Could you please help me understand the cause of this issue? (Two screenshots attached; I've omitted the thousands of error lines between them. The original training log file could not be uploaded due to formatting issues.) I would be extremely grateful if you could help me resolve this!

szrrr04 commented 3 months ago

I noticed there was also a stack trace mixed into the errors:

```
*** Begin stack trace ***
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyObject_MakeTpCall
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
*** End stack trace ***
```

Per train step, total TFLOPs will be 207.49, split as 90.45% learnable weight flops and 9.55% attention flops

```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/u2022212091/u2022212091/DCFormer/jax/MaxText/train.py", line 551, in <module>
    app.run(main)
  File "/home/u2022212091/.conda/envs/20240816/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/u2022212091/.conda/envs/20240816/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/u2022212091/u2022212091/DCFormer/jax/MaxText/train.py", line 547, in main
    train_loop(config)
  File "/home/u2022212091/u2022212091/DCFormer/jax/MaxText/train.py", line 471, in train_loop
    state, metrics = p_train_step(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/xla/xla/service/cpu/cpu_compiler.cc:1048) !llvm::verifyModule(llvm_module, &err_stream) Invalid LLVM IR before optimizations:
```
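For anyone hitting a similar lowering failure: the two debugging switches the error message mentions can be set as environment variables before jax is imported. A minimal sketch (the dump directory `/tmp/xla_dump` is an arbitrary choice, not anything from the repo):

```python
import os

# Ask XLA to dump the HLO/LLVM IR it generates, as the error message suggests
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

# Show JAX's internal frames in tracebacks instead of the simplified version
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

# Note: jax must be imported *after* these variables are set for them to take effect.
```

With these set, the dump directory will contain the HLO modules XLA produced, which usually narrows down which op is being lowered with the wrong element type.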

Lisennlp commented 3 months ago

Your problem seems to be a data-type problem. First, check whether the types of the data you feed in are correct; the expected types are shown in the attached screenshot. Second, check whether the jax you installed is compatible with your CUDA version. The CUDA version we tested, and the jax/jaxlib versions, are shown in the screenshots.
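To illustrate the kind of check meant here: token-ID batches in an LM data pipeline are normally integer-typed, and a stray float array fed into integer ops (or vice versa) can surface as exactly this sort of lowering error. A hypothetical sketch — the field names `inputs`/`targets` and shapes are assumptions for illustration, not the repo's actual schema:

```python
import numpy as np

# Hypothetical batch as produced by a tf.data/tfds pipeline
batch = {
    "inputs": np.zeros((8, 2048), dtype=np.int32),
    "targets": np.zeros((8, 2048), dtype=np.int32),
}

# Token IDs should be integers; catching a stray float32 here is much
# easier than decoding an XLA lowering error later.
for name, arr in batch.items():
    assert np.issubdtype(arr.dtype, np.integer), f"{name} has dtype {arr.dtype}"
```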

If the above does not solve your problem, you can try our dataset (https://huggingface.co/datasets/Caiyun-AI/Pile_tfrecord/tree/main) to see whether it runs. Since the full dataset is fairly large, you can use just one file for the experiment. Attached is a log of an experiment I just ran following the README.

I hope the above can help you.

szrrr04 commented 3 months ago

Thank you! I checked, and first of all, the type of the data I entered is correct. I didn’t modify anything related to data types.

Secondly, all the package versions I installed are correct, exactly as specified in the requirements_gpu.txt. However, I found that my CUDA version might not be right because I’m using the school's supercomputing platform. I noticed that the current CUDA version on the platform is pointing to 11.8, but I don’t have the permission to change it, so I might need to contact the administrator later.

Without changing the CUDA version, I tried switching back to the default Pile dataset, but I still encountered the same issue. This makes it almost certain that the problem is related to the CUDA version rather than to my dataset modifications.

I’d like to ask, since the current CUDA versions available on the supercomputing platform are 12.2 and 12.4, can I ask the administrator to change the current version to point to 12.2 instead of having them install version 12.1? Does the CUDA compiler version have to be exactly 12.1, or is any version above 12.1 acceptable?

I look forward to your reply, and thank you in advance. (Screenshot attached.)

Lisennlp commented 3 months ago

I reproduced your problem after switching the jax version to jax[cuda12_pip]==0.4.25. After upgrading to jax[cuda12_pip]==0.4.30, it ran successfully again. So the only requirement is that the jax and CUDA versions be compatible; it is not limited to CUDA 12.1. You can therefore adapt to your CUDA version by changing the jax version.
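For reference, the upgrade described here might look like the following, using the extras named in this thread and the standard JAX CUDA wheel index (adjust `cuda12_pip` to `cuda11_pip` for a CUDA 11.x toolchain):

```shell
# Upgrade jax/jaxlib to the CUDA 12 build that worked in this thread
pip install --upgrade "jax[cuda12_pip]==0.4.30" \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```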

szrrr04 commented 3 months ago

Okay, thank you, I understand. So you're saying I can keep my CUDA version at 11.8 and just adjust my JAX version, right? For example, I could try jax[cuda11_pip]==0.4.30 on top of CUDA 11.8? (I'm currently using jax[cuda11_pip]==0.4.25.)

Lisennlp commented 3 months ago

Yes. Besides that, you can also try jax[cuda12_pip]==0.4.30 on CUDA 12.2 or 12.4.

szrrr04 commented 3 months ago

Thank you very much! I tried upgrading JAX to 0.4.30 while keeping CUDA at 11.8 (since contacting the administrator is a bit inconvenient), and that solved the problem that had been troubling me for a long time!

Lisennlp commented 3 months ago

Hello, your error does not seem to be out-of-HBM; your 120 GB of HBM is enough to train a 6B model with length 2048. I suspect the jax version you installed is still wrong and that it cannot see the GPU. Try running `python -c 'import jax; print(len(jax.devices()))'` on the command line to see whether jax can find the GPU devices. If it prints 0, the GPU cannot be obtained.

Lisennlp commented 3 months ago

If it is due to the cuda and jax versions, you can try installing jax[cuda12_pip]==0.4.30 and setting your cuda version to 12.2 or 12.4. You do not need to reinstall cuda. If you do not know how to switch versions, you can refer to: https://www.cnblogs.com/huadongw/p/15193247.html
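The version switch discussed here usually amounts to pointing the shell at an already-installed toolkit. A minimal sketch of the `~/.bashrc` settings involved, assuming CUDA 12.2 lives under `/usr/local/cuda-12.2` (the path is an assumption; supercomputing platforms often install it elsewhere):

```shell
# Point the shell at an existing CUDA 12.2 installation (path is an assumption)
export CUDA_HOME=/usr/local/cuda-12.2
export PATH="$CUDA_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
```

Afterwards, `nvcc --version` should report 12.2 in a fresh shell.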

szrrr04 commented 3 months ago

Thank you for your help! I successfully set up CUDA 12.2 following the link you provided. However, when I tried to install jax[cuda12_pip]==0.4.30, it said pip could not be found. Is this because I modified the PATH? How can I resolve this? (Screenshots attached.)

szrrr04 commented 3 months ago

I think I've already solved this problem. I added the directory containing my pip/python to the PATH, and now pip can be found! Is this the correct solution? (Screenshot attached.)
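The fix described here amounts to something like the following, assuming the conda environment named `20240816` that appears in the traceback earlier in this thread (adjust the env name and path to your setup):

```shell
# Put the conda environment's bin directory (which contains pip/python) first on PATH
export PATH="$HOME/.conda/envs/20240816/bin:$PATH"
```

Prepending (rather than appending) ensures this environment's pip shadows any system-wide one.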

Lisennlp commented 3 months ago

> I think this problem has already been solved by me. I added my PIP_OR_PYTHON_PATH to the PATH, and now my pip can be found! Is this the correct solution?

Yes, you are right!

szrrr04 commented 3 months ago

Hello, I think I'm ready now! I have set the CUDA environment variables in my ~/.bashrc file, pointing at CUDA 12.2 (screenshot P1). Additionally, I've successfully installed jax[cuda12_pip]==0.4.30 (screenshot P2). I should be all set now, right? One last question: yesterday, before I figured out how to set up CUDA 12.2, I ran `conda install nvidia/label/cuda-12.1.1::cuda-tools` in my virtual environment, which installed a bunch of packages that start with cuda (screenshot P3). Is this a problem? Should I remove them, or can I just leave them as they are?

Lisennlp commented 3 months ago

All are fine. The extra packages just take up disk space; if you don't need them, you can selectively uninstall them.

szrrr04 commented 3 months ago

Finally, it's all working!! Everything went smoothly. Thank you so, so much! (Screenshot attached.) I think I also finally understand the cause of my earlier errors, and why it seemed to work yesterday but didn't use the GPU. Earlier, when I hit the data-type errors, I hadn't set the CUDA environment variables in my .bashrc, while my jax and jaxlib builds were for CUDA 11, which caused the errors. Last night I upgraded to jax[cuda11_pip]==0.4.30, but during installation it warned that no CUDA version was found, so I ended up installing the non-CUDA builds of jax and jaxlib. Since no CUDA compiler was configured, the code ran but couldn't use the GPU. Now that everything is configured correctly, it all works!

Thank you so much again! Without your help, I would have struggled for much longer! Best wishes~

szrrr04 commented 3 months ago

However, I noticed that on my school's supercomputing platform, with 16 CPU cores and 2 GPUs, the throughput is only 0.361 steps/s. I calculated that training for the 10,000,000 steps specified in the configuration file would still take a very long time. Can the 10,000,000 steps be shortened appropriately, or will the code automatically stop training once the metrics stabilize? One more question: if my training is interrupted due to a timeout, the code will automatically load the checkpoint files the next time I resume training, correct? I don't need to configure anything manually, right? Thank you very much~
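For a rough sense of scale, the arithmetic behind this concern, using the numbers in the message above:

```python
steps = 10_000_000        # max steps in the config
steps_per_sec = 0.361     # observed throughput on 2 GPUs
seconds_per_day = 86_400

days = steps / steps_per_sec / seconds_per_day
# At 0.361 steps/s, 10M steps would take roughly 320 days of wall-clock time,
# which is why stopping early (see the reply below) is the practical answer.
```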

Lisennlp commented 3 months ago

The 10 million steps is not the total number of steps you will train for; it is just there to keep training going, so you need to stop the program manually, for example when the loss drops to a certain value. In addition, if the program is unexpectedly interrupted, the latest checkpoint will be loaded automatically after you restart, and the data will resume reading from the point of interruption. You don't need to configure anything manually.

szrrr04 commented 3 months ago

Thank you! I got it!
