erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
168 stars 19 forks

Training on TPU pod #34

JinSeoungwoo closed this issue 8 months ago

JinSeoungwoo commented 8 months ago

Can I train on a TPU v3-32 pod?

erfanzar commented 8 months ago

Yes, for sure. Can you tell me which platform you're using to train your model, so I can give you a step-by-step example?

JinSeoungwoo commented 8 months ago

I'm using GCP right now with v2-alpha-pod, and I want to fine-tune Mistral on a v3-32.

erfanzar commented 8 months ago

Actually, the normal training code should work just fine, but you can use a custom mesh size of 1,-1,1 or 2,-1,1.
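A mesh size such as 1,-1,1 reads like a NumPy reshape: -1 absorbs whatever devices remain. The sketch below shows how that resolves on a pod using JAX's own `Mesh`. It assumes the three axes mean data, FSDP, and model parallelism and names them "dp", "fsdp", "mp"; those axis names and both helper functions are hypothetical, not EasyDeL's API.

```python
import numpy as np
import jax
from jax.sharding import Mesh


def resolve_mesh_shape(shape, n_devices):
    """Replace a single -1 in `shape` so the product of all axes equals n_devices."""
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = int(np.prod([s for s in shape if s != -1]))
    if -1 in shape:
        if n_devices % known != 0:
            raise ValueError(f"{n_devices} devices cannot be split by {known}")
        shape[shape.index(-1)] = n_devices // known
    elif known != n_devices:
        raise ValueError(f"mesh {tuple(shape)} needs {known} devices, found {n_devices}")
    return tuple(shape)


def make_mesh(shape, axis_names=("dp", "fsdp", "mp")):
    """Build a Mesh over every device in every process (axis names are hypothetical)."""
    devices = np.array(jax.devices())  # global device list, all hosts included
    resolved = resolve_mesh_shape(shape, devices.size)
    return Mesh(devices.reshape(resolved), axis_names)


# On a v3-32 (32 TPU cores across 4 hosts):
#   resolve_mesh_shape((1, -1, 1), 32) -> (1, 32, 1)  all cores on the middle axis
#   resolve_mesh_shape((2, -1, 1), 32) -> (2, 16, 1)  two groups of 16 cores
```

A plain reshape ignores the physical interconnect; `jax.experimental.mesh_utils.create_device_mesh` is the topology-aware alternative if ordering matters.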

JinSeoungwoo commented 8 months ago

I saw https://cloud.google.com/tpu/docs/jax-pods?hl=ko in the TPU Pod documentation. Can I do the same with EasyDeL?

erfanzar commented 8 months ago

I believe EasyDeL automatically runs the code across all available devices, but yes, you have to launch it like this:

gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"
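Here example.py stands in for your own script. Before launching real training, a small check like the one below, modeled on the device-count and psum test in the Cloud TPU JAX pods guide, can confirm that every worker joined the same job. The function name is hypothetical. On a v3-32, each of the 4 workers should report 8 local devices, 32 global devices, and a psum of 32.

```python
import jax
import jax.numpy as jnp


def check_devices():
    """Report how this process sees the pod and run one cross-device reduction."""
    # Each process contributes a 1.0 per local device; psum adds them over every
    # device pmap spans, which on a multi-host pod means all hosts.
    xs = jnp.ones(jax.local_device_count())
    summed = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
    return {
        "process_index": jax.process_index(),
        "process_count": jax.process_count(),
        "local_devices": jax.local_device_count(),
        "global_devices": jax.device_count(),
        "psum": float(summed[0]),
    }


if __name__ == "__main__":
    print(check_devices())
```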

Make sure to install jax and jaxlib version >= 0.4.10:

gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>=0.4.10' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
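Because the install runs separately on every worker, it can be worth confirming that all of them ended up with a new enough version. A minimal sketch, assuming it is saved under a hypothetical name such as check_jax.py and launched with the same `--worker=all --command="python3 check_jax.py"` pattern:

```python
import re

import jax

MIN_JAX = (0, 4, 10)


def version_tuple(version: str) -> tuple:
    """Parse the leading numeric part of a version, e.g. '0.4.23.dev1' -> (0, 4, 23)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def check_jax_version(minimum=MIN_JAX):
    """Raise if jax or jaxlib on this worker is older than `minimum`."""
    import jaxlib

    found = {"jax": jax.__version__, "jaxlib": jaxlib.__version__}
    for name, ver in found.items():
        if version_tuple(ver) < minimum:
            raise RuntimeError(f"{name} {ver} is older than {'.'.join(map(str, minimum))}")
    return found


if __name__ == "__main__":
    print(jax.process_index(), check_jax_version())
```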

I suspect there's a bug in the Mistral model, and I suggest you don't train your models until I fix it (10 to 18 hours).

JinSeoungwoo commented 8 months ago

I've tried a lot, but after a certain number of steps, the accuracy drops dramatically. Is it possible that I'm setting the parameters incorrectly? Or is it related to a bug?

erfanzar commented 8 months ago

Yes, that's related to the bug.

JinSeoungwoo commented 8 months ago

Oh, until now I thought I was setting it up wrong. I'll do the training after the bug is fixed. Thanks for fixing it!

erfanzar commented 8 months ago

Fixed <3

JinSeoungwoo commented 8 months ago

Thank you for fixing the error. But I got this error when training on the TPU pod:

Traceback (most recent call last):
  File "/root/tune.py", line 249, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/tune.py", line 239, in main
    output = trainer.train(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/fsdp_train.py", line 427, in train
    'perplexity': jnp.exp(loss).tolist(),
RuntimeError: Running operations on `Array`s that are not fully addressable by this process (i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It’s very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context manager.
JinSeoungwoo commented 8 months ago

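I fixed it by wrapping the W&B logging call in `with jax.spmd_mode("allow_all"):`

import jax
import jax.numpy as jnp

# Inside the trainer's training loop (EasyDel/trainer/fsdp_train.py).
# spmd_mode("allow_all") permits eager ops such as jnp.exp(loss) on arrays
# sharded across processes; every process must still reach this block
# in the same order, or the pod can hang.
with jax.spmd_mode("allow_all"):
    self.wandb_runtime.log(
        {
            "loss": loss.tolist(),
            "learning_rate": self.scheduler(
                sharded_train_state_.step.tolist()
            ).tolist(),
            "step": sharded_train_state_.step.tolist(),
            "step time": ttl_time,
            "perplexity": jnp.exp(loss).tolist(),
            "accuracy": accuracy.tolist(),
            "avg_accuracy": (sum(accuracies) / len(accuracies)).tolist(),
            "mem_res": mem_res,
        }
    )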
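The spmd_mode workaround lets eager operations such as jnp.exp(loss) run on arrays that span several processes. A different approach, sketched here as an assumption rather than how EasyDeL does it, is to compute derived metrics like perplexity inside a jitted function and only copy the finished scalar to the host. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def loss_and_perplexity(logits, labels):
    """Mean token cross-entropy and its exponential, computed in one jitted call."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Negative log-likelihood of each target token.
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    loss = nll.mean()
    return loss, jnp.exp(loss)


def scalar_to_host(x: jax.Array) -> float:
    """Fetch a scalar without eager math on a possibly multi-process array.

    A 0-d array cannot be partitioned, only replicated, so any local shard
    already holds the full value.
    """
    return float(np.asarray(x.addressable_shards[0].data))


logits = jnp.zeros((2, 3, 4))  # uniform predictions over a 4-token vocabulary
labels = jnp.zeros((2, 3), dtype=jnp.int32)
loss, ppl = loss_and_perplexity(logits, labels)
print(scalar_to_host(loss), scalar_to_host(ppl))  # ~1.386 (ln 4) and ~4.0
```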
erfanzar commented 8 months ago

Thank you for sharing the code, but are you sure it works correctly for training (MistralTrain)? The model works fine at inference, but in training the loss rises from 1.05648 to around 4 after five steps.
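Since the logged perplexity is jnp.exp(loss), a loss jump of this size is a much larger jump in perplexity. Assuming the loss is the mean token-level cross-entropy in nats:

$$
\mathrm{PPL} = e^{\mathcal{L}}, \qquad e^{1.05648} \approx 2.88, \qquad e^{4} \approx 54.6
$$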

erfanzar commented 8 months ago

fixed <3

erfanzar commented 8 months ago

[Screenshot]

JinSeoungwoo commented 8 months ago

I don't know if the code I posted is appropriate, but it worked fine in my training, so I used it. Do you think I should restart the training with the updated EasyDeL?

erfanzar commented 8 months ago

Can you give me more information, such as a screenshot of the W&B run (loss and perplexity only)?

JinSeoungwoo commented 8 months ago

[Screenshot]

Here's a screenshot of the loss and perplexity.

erfanzar commented 8 months ago

Yes, you can continue training as long as you haven't seen the loss stay above ~7 for more than 10 steps.
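That rule of thumb can be turned into a small check in the training loop. This is a hypothetical helper, not part of EasyDeL; it assumes "more than 10 steps" means consecutive steps, and takes the threshold of 7 from the reply.

```python
from dataclasses import dataclass


@dataclass
class DivergenceMonitor:
    """Flag a run whose loss stays above `threshold` for more than `patience` consecutive steps."""

    threshold: float = 7.0
    patience: int = 10
    _streak: int = 0  # consecutive steps above the threshold so far

    def update(self, loss: float) -> bool:
        """Record one step's loss; return True once the run looks diverged."""
        self._streak = self._streak + 1 if loss > self.threshold else 0
        return self._streak > self.patience
```

Usage: call `monitor.update(float(loss))` after each step and stop or roll back to the last checkpoint when it returns True.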

JinSeoungwoo commented 8 months ago

Is it possible to save the model at the save step only on the main node and push it to the Hugging Face repo? I edited the code to do that with

if jax.process_index() == 0:

but when it reached the save step, training stopped on all nodes and made no further progress.
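A likely cause is the rule stated in the earlier traceback: all processes must run the same cross-process computations in the same order. If saving involves a collective (for example gathering sharded weights) and only process 0 enters it, the other processes block at their next collective. A common pattern, sketched here under that assumption, keeps collective work on every process, gates only host-side I/O on process 0, and then waits at a barrier. The function name is hypothetical.

```python
import jax
from jax.experimental import multihost_utils


def run_on_main_process(step, side_effect):
    """Run a host-only side effect (such as an upload) on process 0 only.

    Collective work, such as gathering sharded parameters for a checkpoint,
    must NOT go inside this function: every process has to run it.
    """
    if jax.process_index() == 0:
        side_effect(step)
    # Every process must call the barrier with the same name; any process
    # that skips it leaves the others waiting.
    multihost_utils.sync_global_devices(f"main_process_side_effect_{step}")
```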

erfanzar commented 8 months ago

This will cause an error and make the other threads or nodes crash. If you want to do something like this, you should put the other nodes to sleep or lock them. Alternatively, you can write a simple script that checks whether a new checkpoint has been saved and pushes it to the Hugging Face repo.
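The watcher-script idea could look roughly like this. It is a hypothetical sketch: the directory, file pattern, repo id, and polling interval are placeholders, and treating a file as finished once its size is unchanged between two polls is a heuristic, not something EasyDeL guarantees. It runs as a separate process on the main worker, so training never waits on the upload; the upload itself uses huggingface_hub's `HfApi.upload_file`.

```python
import time
from pathlib import Path
from typing import Callable, Dict, Set


def poll_once(ckpt_dir: Path, pattern: str, sizes: Dict[Path, int],
              uploaded: Set[Path], upload: Callable[[Path], None]) -> None:
    """Upload matching files whose size has not changed since the previous poll."""
    for path in sorted(Path(ckpt_dir).glob(pattern)):
        if path in uploaded or not path.is_file():
            continue
        size = path.stat().st_size
        if sizes.get(path) == size:  # unchanged since last poll: assume writing finished
            upload(path)
            uploaded.add(path)
        else:
            sizes[path] = size


def hf_uploader(repo_id: str) -> Callable[[Path], None]:
    """Build an upload callback backed by huggingface_hub (imported lazily)."""
    from huggingface_hub import HfApi

    api = HfApi()

    def upload(path: Path) -> None:
        api.upload_file(path_or_fileobj=str(path), path_in_repo=path.name, repo_id=repo_id)

    return upload


def watch(ckpt_dir: str, pattern: str, upload: Callable[[Path], None],
          interval_s: float = 60.0) -> None:
    """Poll forever; intended to run in its own process next to training."""
    sizes: Dict[Path, int] = {}
    uploaded: Set[Path] = set()
    while True:
        poll_once(Path(ckpt_dir), pattern, sizes, uploaded, upload)
        time.sleep(interval_s)
```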

JinSeoungwoo commented 8 months ago

Thank you for the advice. I fixed it!