Can I train using TPUv3-32 pod?
Closed JinSeoungwoo closed 8 months ago
Yes, for sure. Can you tell me which platform you're using to train your model on, so I can give you a step-by-step example?
I'm using GCP right now with v2-alpha-pod; I want to fine-tune Mistral with a v3-32.
Actually, the normal training code should work just fine, but you can use a custom mesh size of 1,-1,1 or 2,-1,1.
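For reference, here is a minimal plain-JAX sketch of what a 1,-1,1 mesh shape works out to on a pod. The helper and the axis names dp/fsdp/mp below are illustrative assumptions, not EasyDel's own API; the -1 axis just absorbs all remaining devices, so on a v3-32 the shape 1,-1,1 resolves to (1, 32, 1).

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

def make_mesh(shape=(1, -1, 1), axis_names=("dp", "fsdp", "mp")):
    # Resolve the -1 entry to "all remaining devices".
    fixed = 1
    for dim in shape:
        if dim != -1:
            fixed *= dim
    resolved = tuple(jax.device_count() // fixed if dim == -1 else dim for dim in shape)
    devices = mesh_utils.create_device_mesh(resolved)
    return Mesh(devices, axis_names)

mesh = make_mesh((1, -1, 1))  # resolves to (1, 32, 1) on a v3-32 slice
print(mesh)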
I saw https://cloud.google.com/tpu/docs/jax-pods?hl=ko in the TPU Pod documentation; can I do the same with EasyDel?
I guess EasyDel automatically runs the code across all the available devices, but yes, you have to run it like this:
gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="python3 example.py"
Make sure to install jax and jaxlib versions >= 0.4.10:
gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b --worker=all --command="pip install \
--upgrade 'jax[tpu]>0.4.10' \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
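For reference, a minimal sketch of what the example.py in the command above could contain just to verify the pod setup (illustrative only, not EasyDel code): each worker runs the same script, sees only its local devices, and jax.device_count() reports the whole slice.

import jax

# Runs once per TPU VM worker; on a v3-32 this prints 4 processes with
# 8 local devices each and 32 devices in total.
print(
    f"process {jax.process_index()}/{jax.process_count()}: "
    f"{jax.local_device_count()} local devices, {jax.device_count()} total"
)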
I guess there's a bug in the Mistral model, and I suggest you don't train your models till I fix that (10~18 hours).
I've tried a lot, but after a certain number of steps, the accuracy drops dramatically. Is it possible that I'm setting the parameters incorrectly? Or is it related to a bug?
yes that's related to the bug
Oh.. I thought I was setting it up wrong until now. I'll have to do the training after the bug is fixed. Thanks for the fix!
fix <3
Thank you for fixing the error. But I got this error when training on a TPU pod:
Traceback (most recent call last):
  File "/root/tune.py", line 249, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/tune.py", line 239, in main
    output = trainer.train(
  File "/usr/local/lib/python3.10/dist-packages/EasyDel/trainer/fsdp_train.py", line 427, in train
    'perplexity': jnp.exp(loss).tolist(),
RuntimeError: Running operations on `Array`s that are not fully addressable by this process (i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It's very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you're not already familiar with JAX's multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context manager.
Fixed it by adding with jax.spmd_mode("allow_all"): around the logging call:
with jax.spmd_mode("allow_all"):
    # Allow cross-process reads of globally sharded arrays when logging metrics.
    self.wandb_runtime.log(
        {
            "loss": loss.tolist(),
            "learning_rate": self.scheduler(
                sharded_train_state_.step.tolist()
            ).tolist(),
            "step": sharded_train_state_.step.tolist(),
            "step time": ttl_time,
            "perplexity": jnp.exp(loss).tolist(),
            "accuracy": accuracy.tolist(),
            "avg_accuracy": (sum(accuracies) / len(accuracies)).tolist(),
            "mem_res": mem_res,
        }
    )
Thank you for sharing the code, but are you sure the code is working fine for training (MistralTrain)? The model works fine in the interface, but in training, after five steps the loss value rises from 1.05648 to 4.~
fixed <3
I don't know if the code I posted is appropriate, but it worked fine in my training, so I used it. Do you think I should restart the training with the updated EasyDeL?
Can you give me more information, something like a screenshot of the W&B process (loss and perplexity only)?
Here's a picture of loss and perplexity
Yes, you can continue training as long as you haven't seen a loss higher than 7.~ for more than 10 steps.
Is it possible to save the model at the save step only on the main node and push it to the Hugging Face repo? I edited the code to do it with
if jax.process_index() == 0:
but at the save step, all training hung and stopped progressing.
This will cause an error and make the other threads or nodes dump, so if you want to do something like this, you should put the other nodes to sleep or lock them. Alternatively, you can write a simple script that checks whether a new ckpt has been saved and pushes it to the Hugging Face repo.
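A minimal sketch of such a watcher script, assuming checkpoints land in a local directory; ckpt_dir, repo_id, and the polling interval below are placeholders (not EasyDel code), and the upload uses huggingface_hub's HfApi.

import os
import time
from huggingface_hub import HfApi  # pip install huggingface_hub

ckpt_dir = "/root/checkpoints"        # wherever the trainer writes ckpts (assumption)
repo_id = "your-username/your-model"  # hypothetical target repo; must already exist
api = HfApi()                         # assumes you are logged in (huggingface-cli login)
pushed = set()

while True:
    for name in sorted(os.listdir(ckpt_dir)):
        if name in pushed:
            continue
        # Upload each new checkpoint file to the Hugging Face repo.
        api.upload_file(
            path_or_fileobj=os.path.join(ckpt_dir, name),
            path_in_repo=name,
            repo_id=repo_id,
        )
        pushed.add(name)
    time.sleep(60)  # check for new checkpoints once a minute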
Thank you for the advice. I fixed it!