pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

GRU loss not converging on TPU #7331

Closed: hungwon closed this issue 4 months ago

hungwon commented 4 months ago

🐛 Bug

I was training a GRU model on a TPU v4-8 for a text classification task, but the loss didn't converge and the validation accuracy stayed the same. RNN and LSTM models show the same problem. When I trained the model on the CPU, it trained well: the loss kept decreasing and the validation accuracy kept increasing. See the training logs below.

CASE 1: Using TPU (only 1 core; the same problem occurs with all cores)

101it [00:58,  2.41it/s]
| epoch   1 |   100/  750 batches | loss    1.384 | accuracy    0.252
200it [01:40,  2.68it/s]
| epoch   1 |   200/  750 batches | loss    1.388 | accuracy    0.255
300it [02:30,  3.47it/s]
| epoch   1 |   300/  750 batches | loss    1.384 | accuracy    0.254
401it [03:02,  3.38it/s]
| epoch   1 |   400/  750 batches | loss    1.393 | accuracy    0.255
499it [03:32,  7.20it/s]
| epoch   1 |   500/  750 batches | loss    1.385 | accuracy    0.254
601it [03:59,  2.47it/s]
| epoch   1 |   600/  750 batches | loss    1.387 | accuracy    0.253
703it [04:28,  5.07it/s]
| epoch   1 |   700/  750 batches | loss    1.385 | accuracy    0.252
750it [04:45,  2.63it/s]
------------------------------------------------------------
| end of epoch   1 | valid accuracy    0.252 
------------------------------------------------------------
~~~
------------------------------------------------------------
101it [00:38,  1.81it/s]
| epoch   4 |   100/  750 batches | loss    1.392 | accuracy    0.259
200it [01:14,  4.99it/s]
| epoch   4 |   200/  750 batches | loss    1.388 | accuracy    0.254
301it [01:48,  2.63it/s]
| epoch   4 |   300/  750 batches | loss    1.385 | accuracy    0.252
403it [02:23,  4.71it/s]
| epoch   4 |   400/  750 batches | loss    1.386 | accuracy    0.250
503it [02:53,  6.41it/s]
| epoch   4 |   500/  750 batches | loss    1.386 | accuracy    0.250
601it [03:21,  5.64it/s]
| epoch   4 |   600/  750 batches | loss    1.387 | accuracy    0.248
701it [03:53,  4.64it/s]
| epoch   4 |   700/  750 batches | loss    1.383 | accuracy    0.248
750it [04:13,  2.96it/s]
------------------------------------------------------------
| end of epoch   4 | valid accuracy    0.252 
------------------------------------------------------------
102it [00:39,  4.65it/s]
| epoch   5 |   100/  750 batches | loss    1.386 | accuracy    0.255
201it [01:13,  2.91it/s]
| epoch   5 |   200/  750 batches | loss    1.389 | accuracy    0.251
301it [01:49,  2.32it/s]
| epoch   5 |   300/  750 batches | loss    1.384 | accuracy    0.252
400it [02:28,  2.73it/s]
| epoch   5 |   400/  750 batches | loss    1.386 | accuracy    0.252
502it [02:58,  4.92it/s]
| epoch   5 |   500/  750 batches | loss    1.391 | accuracy    0.252
601it [03:31,  4.56it/s]
| epoch   5 |   600/  750 batches | loss    1.383 | accuracy    0.252
701it [04:03,  4.10it/s]
| epoch   5 |   700/  750 batches | loss    1.388 | accuracy    0.252
750it [04:17,  2.92it/s]
------------------------------------------------------------
| end of epoch   5 | valid accuracy    0.253 
------------------------------------------------------------
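For context, assuming the model is trained with cross-entropy loss over AG_NEWS's four classes: a classifier that assigns equal probability to every class has a loss of ln(4) ≈ 1.386 and about 25% accuracy, which is exactly where the TPU run sits in every epoch above. A minimal sketch of that arithmetic:

```python
import math


def uniform_cross_entropy(num_classes: int) -> float:
    """Cross-entropy of a prediction that gives every class probability 1/num_classes.

    With equal logits, softmax yields p = 1/K for each class, so the loss for any
    target is -ln(1/K) = ln(K), regardless of which class is correct.
    """
    logits = [0.0] * num_classes  # equal logits -> uniform softmax
    log_sum_exp = math.log(sum(math.exp(z) for z in logits))
    target = 0  # any target index gives the same value
    return log_sum_exp - logits[target]


if __name__ == "__main__":
    k = 4  # AG_NEWS has four classes
    print(f"uniform loss: {uniform_cross_entropy(k):.3f}")  # uniform loss: 1.386
    print(f"chance accuracy: {1 / k:.3f}")                   # chance accuracy: 0.250
```

By contrast, the CPU logs below drop well under 1.386 within the first epoch.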

CASE 2: Using CPU

117it [00:01, 103.56it/s]
| epoch   1 |   100/  750 batches | loss    1.370 | accuracy    0.314
216it [00:02, 103.94it/s]
| epoch   1 |   200/  750 batches | loss    1.311 | accuracy    0.364
315it [00:03, 104.35it/s]
| epoch   1 |   300/  750 batches | loss    1.208 | accuracy    0.396
414it [00:04, 103.02it/s]
| epoch   1 |   400/  750 batches | loss    1.244 | accuracy    0.416
513it [00:05, 102.56it/s]
| epoch   1 |   500/  750 batches | loss    1.122 | accuracy    0.436
621it [00:06, 99.37it/s] 
| epoch   1 |   600/  750 batches | loss    1.074 | accuracy    0.451
715it [00:07, 98.45it/s] 
| epoch   1 |   700/  750 batches | loss    1.060 | accuracy    0.463
750it [00:07, 101.30it/s]
------------------------------------------------------------
| end of epoch   1 | valid accuracy    0.556 
------------------------------------------------------------
~~~ 
------------------------------------------------------------
120it [00:01, 102.07it/s]
| epoch   4 |   100/  750 batches | loss    0.904 | accuracy    0.679
219it [00:02, 102.01it/s]
| epoch   4 |   200/  750 batches | loss    0.793 | accuracy    0.683
318it [00:03, 103.15it/s]
| epoch   4 |   300/  750 batches | loss    0.814 | accuracy    0.684
417it [00:04, 103.16it/s]
| epoch   4 |   400/  750 batches | loss    0.899 | accuracy    0.687
516it [00:05, 101.30it/s]
| epoch   4 |   500/  750 batches | loss    0.814 | accuracy    0.687
615it [00:06, 102.32it/s]
| epoch   4 |   600/  750 batches | loss    0.758 | accuracy    0.689
714it [00:06, 103.07it/s]
| epoch   4 |   700/  750 batches | loss    0.862 | accuracy    0.690
750it [00:07, 101.92it/s]
------------------------------------------------------------
| end of epoch   4 | valid accuracy    0.695 
------------------------------------------------------------
120it [00:01, 102.87it/s]
| epoch   5 |   100/  750 batches | loss    0.698 | accuracy    0.702
219it [00:02, 102.33it/s]
| epoch   5 |   200/  750 batches | loss    0.732 | accuracy    0.706
318it [00:03, 102.81it/s]
| epoch   5 |   300/  750 batches | loss    0.788 | accuracy    0.703
417it [00:04, 100.79it/s]
| epoch   5 |   400/  750 batches | loss    0.785 | accuracy    0.703
516it [00:05, 102.47it/s]
| epoch   5 |   500/  750 batches | loss    0.657 | accuracy    0.705
615it [00:06, 100.41it/s]
| epoch   5 |   600/  750 batches | loss    0.786 | accuracy    0.706
714it [00:07, 103.50it/s]
| epoch   5 |   700/  750 batches | loss    0.716 | accuracy    0.707
750it [00:07, 101.74it/s]
------------------------------------------------------------
| end of epoch   5 | valid accuracy    0.704 
------------------------------------------------------------

Code

The only differences are (1) the code that sets the device and (2) the optimization step (e.g. optimizer.step() vs. xm.optimizer_step(optimizer)).
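The two differences can be isolated so the rest of the training loop stays identical on both backends. A minimal sketch, assuming a standard PyTorch training loop: the helper name `optimizer_step_for` and the CPU fallback are hypothetical and not taken from either notebook; the only torch_xla calls used are `xm.xla_device()` and `xm.optimizer_step(optimizer, barrier=True)`, the form JackCaoG mentions in the first reply.

```python
try:
    # TPU path: requires torch_xla to be installed
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
except ImportError:
    # CPU path: torch_xla is not available
    xm = None
    device = "cpu"


def optimizer_step_for(optimizer, xla_model=None):
    """Apply one optimizer update on whichever backend is in use.

    Hypothetical helper: `xla_model` is expected to be the
    `torch_xla.core.xla_model` module, or None for a plain CPU step.
    """
    if xla_model is None:
        optimizer.step()
    else:
        # xm.optimizer_step all-reduces gradients across replicas, calls
        # optimizer.step(), and with barrier=True also calls xm.mark_step(),
        # so the lazily traced graph for this iteration executes here
        xla_model.optimizer_step(optimizer, barrier=True)


# Inside the training loop (model, loss, and optimizer defined elsewhere):
#   model.to(device); move each batch's tensors .to(device)
#   optimizer.zero_grad(); loss.backward()
#   optimizer_step_for(optimizer, xm)
```

Keeping the branch in one place makes it easy to confirm that nothing else differs between the two runs.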

CASE 1 Code: TPU version

https://colab.research.google.com/drive/1P238vbj3UvyADC9W03aPIFmQADhJaCMx?usp=sharing

CASE 2 Code: CPU version

https://colab.research.google.com/drive/1LLiqPfUM19kGQ50tBVOizLOH5B4iJwa8?usp=sharing

Environment

Reference

https://colab.research.google.com/drive/1HqSnvBnqSEEugKZEvg84pQLX3f2s-ebS#scrollTo=YXmpvx29ta8U

JackCaoG commented 4 months ago

I took a quick look and the model code looks right to me; step execution will be triggered by

xm.optimizer_step(optimizer, barrier=True) # If you use TPU

The only thing I can think of is that on TPU, matmuls run in bf16 by default even when the data type is f32, but I don't think that's the root cause of the issue.
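To get a feel for the precision the bf16 remark implies, here is a minimal standard-library sketch that rounds an f32 value to bfloat16 by keeping the upper 16 bits of its f32 encoding with round-to-nearest-even. It only illustrates the number format; it says nothing about which operations torch_xla actually runs in bf16, and NaN/overflow edge cases are deliberately ignored.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 keeps f32's sign bit and 8 exponent bits but only 7 explicit
    mantissa bits, i.e. the upper 16 bits of the f32 encoding.
    NaN and overflow handling are omitted in this sketch.
    """
    # Reinterpret the f32 encoding of x as an unsigned 32-bit integer
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Adding 0x7FFF, plus 1 when the lowest kept bit is set, gives ties-to-even
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    (rounded,) = struct.unpack("<f", struct.pack("<I", bits))
    return rounded


if __name__ == "__main__":
    for value in (1.0, 1.0 + 2**-7, 1.0 + 2**-8, 1.0 + 2**-10, 3.14159):
        print(f"{value!r:>22} -> {round_to_bf16(value)!r}")
```

Near 1.0 the spacing between bf16 values is 2**-7 (about 0.008), so any smaller difference is rounded away.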

@ManfeiBai can you follow up on this issue? I tried to run the TPU colab but ran into

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-87a975e0cc13> in <cell line: 2>()
      1 data_dir = './data' # TODO
----> 2 dataset, test_data = AG_NEWS(root=data_dir)

/usr/local/lib/python3.10/dist-packages/torchtext/datasets/ag_news.py in AG_NEWS(root, split)
     61     """
     62     if not is_module_available("torchdata"):
---> 63         raise ModuleNotFoundError(
     64             "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
     65         )

ModuleNotFoundError: Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data


I think it can be fixed by just installing torchdata.
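The traceback shows torchtext's AG_NEWS checking whether torchdata is available before loading anything. A minimal standard-library sketch of the same kind of check, which fails early with an actionable message before any TPU setup; the helper name `require_module` is hypothetical.

```python
import importlib.util


def require_module(name: str, hint: str = "") -> None:
    """Raise ModuleNotFoundError early if top-level module `name` cannot be imported.

    importlib.util.find_spec locates the module without executing it.
    """
    if importlib.util.find_spec(name) is None:
        message = f"Package `{name}` not found."
        if hint:
            message += f" {hint}"
        raise ModuleNotFoundError(message, name=name)


# Example: run this before calling torchtext.datasets.AG_NEWS
# require_module("torchdata", "See https://github.com/pytorch/data")
```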

ManfeiBai commented 4 months ago

Thanks, @hungwon. Building torchdata from source might help.

The code (rnn_tpu) passed on my local v4-8: https://gist.github.com/ManfeiBai/bb0cdc0d5c1bae831cd7f2c38e92f375

and the commands I used: https://gist.github.com/ManfeiBai/03b7047824ef6abf147202a1c16be8f2

hungwon commented 4 months ago

It seems we used different versions of PyTorch and torch_xla. Do I have to change versions to fix the problem? Is there any way to fix the issue with the same versions?

JackCaoG commented 4 months ago

@ManfeiBai I don't think torchdata is the issue; it's only needed for us to repro. Can you try to run the colab and check why the loss is not going down? You can either use that colab or copy the code into your own TPU VM.

ManfeiBai commented 4 months ago

@hungwon, would you mind sharing the torch, torch_xla, torchtext, and torchdata versions you are using locally or in Colab? I will try to reproduce with your config to figure out why the loss is not going down.
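One way to collect exactly these versions is the standard library's importlib.metadata, which reads installed package metadata without importing the packages, so it still works when an import would fail. A minimal sketch; the helper name is hypothetical, and it assumes Python 3.10+, where `torch_xla` and `torch-xla` are treated as the same distribution name.

```python
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ("torch", "torch_xla", "torchtext", "torchdata")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<12} {ver or 'not installed'}")
```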

Assuming you are using torch/torch_xla 2.3.0, I tested locally on a v4-8 and hit the same issue (accuracy didn't change): https://gist.github.com/ManfeiBai/49f19a407ff18b5ebc93c6e2a9ff4b1e, with torchdata built from source:

pip3 install torch==2.3.0 torchvision torchaudio torchtext torchdata --index-url https://download.pytorch.org/whl/test/cpu

pip3 install torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html

pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

git clone https://github.com/pytorch/data.git torchdata
cd torchdata/
pip install .
cd ..

Compared with the environment setup in the comment above, there is a gap between torch/torch_xla 2.3.0 and the nightly build from source; I will continue to investigate and bisect the issue.
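The bisection mentioned here can be sketched generically: with builds ordered from oldest to newest, where the 2.3.0 side shows the stuck accuracy and the newer nightly side trains, a binary search finds the first working build after about log2(n) training runs. A minimal sketch; `first_fixed_build` and the `trains_ok` predicate are hypothetical placeholders for "install this build and check whether validation accuracy moves".

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def first_fixed_build(builds: Sequence[T], trains_ok: Callable[[T], bool]) -> T:
    """Return the earliest build for which trains_ok(build) is True.

    Assumes the outcome is monotonic (every build before the fix fails and
    every build from the fix onward passes) and that the last build passes.
    """
    if not builds:
        raise ValueError("need at least one build")
    lo, hi = 0, len(builds) - 1  # invariant: builds[hi] passes
    while lo < hi:
        mid = (lo + hi) // 2
        if trains_ok(builds[mid]):
            hi = mid  # the fix is at mid or earlier
        else:
            lo = mid + 1  # the fix is after mid
    return builds[lo]
```

Each call to `trains_ok` is a full reinstall plus a short training run, so keeping the number of probes logarithmic matters more than the search code itself.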


By the way, this is my local environment with the torch and torch_xla versions from the comment above, in which the accuracy does change:

(torch310) root@b7b12c30e894:/pytorch/xla# pip3 list | grep torch
torch                        2.4.0a0+git9554300      /pytorch
torch_xla                    2.4.0+git4aa50f3        /pytorch/xla
torchtext                    0.18.0

Please feel free to try these commands to set up the same versions of torch and torch_xla: https://gist.github.com/ManfeiBai/331703ce9b299f446625c144a9b1c73d

hungwon commented 4 months ago

I appreciate your help! I'm using the following versions of packages.

torch                        2.3.1
torch-xla                    2.3.0
torchdata                    0.7.1a0+b0e25e2
torchtext                    0.18.0
hungwon commented 4 months ago

@ManfeiBai I have listed the package versions above. I was wondering if you could give an estimate of how long I might need to wait for next steps or further guidance from your side. Your assistance is greatly appreciated, and I am eager to resolve this issue.

Thank you again for your support!

ManfeiBai commented 4 months ago

Hi, @hungwon, thanks for the patience!

I would recommend trying nightly torch and torch_xla, setting up the environment with these commands:

!pip uninstall torch torchvision torchtext torch_xla libtpu-nightly -y
!pip3 install --pre torch torchvision torchtext --index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
!pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

I have also tried them in this Colab notebook with the CASE 1 (TPU) code above; please feel free to try it: https://colab.sandbox.google.com/drive/1Z-d8SON4O3aNuO744gJhowI3IItpqjZA#scrollTo=ENahPkiJpyqI

thanks for your patience and please feel free to reply with any too

ManfeiBai commented 4 months ago

Hi, @hungwon, sorry for the access permission issue. Please try these two links for the Colab with nightly torch/torch_xla:

or use these commands to set up the environment: https://gist.github.com/ManfeiBai/96d53837ad7bffc615feb236d10aa725

thanks for your patience, and please let me know if these links are not accessible

hungwon commented 4 months ago

@ManfeiBai Thank you for your help. Besides this problem, I have another issue. Whenever I run "git clone torch_xla ~~~" and import test_utils for logging training_loss in my code, I get the following error.
Should I just run python setup.py install?

  File "/home/hungwon3626/DeepLearning/src/Trainer/trainer.py", line 15, in <module>
    import xla.torch_xla.test.test_utils as test_utils
  File "/home/hungwon3626/xla/torch_xla/__init__.py", line 152, in <module>
    from .version import __version__
ModuleNotFoundError: No module named 'xla.torch_xla.version'
ManfeiBai commented 4 months ago

It looks like your current path might not pick up the right torch_xla. Would you mind confirming with something like pip3 list | grep torch or pwd?

Also, I want to confirm xla.torch_xla.test.test_utils: does xla mean torch_xla here?

I gave it a quick try locally:

(7731newjun30) root@b7b12c30e894:/# PJRT_DEVICE=TPU python
Python 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch_xla.test.test_utils as test_utils
>>> 
hungwon commented 4 months ago

@ManfeiBai

  1. I've forked torch_xla and edited the logging-related methods. Then I installed torch_xla with the commands you provided (I'll call it default_torch_xla) and also git-cloned the forked version of torch_xla. All I want is to load only the edited test_utils from the forked torch_xla for logging and use default_torch_xla for everything else. When I ran the Jupyter notebook, the error didn't occur. By the way, my project path is /home/hungwon3626/DeepLearning/Test and the forked torch_xla path is /home/hungwon3626/xla.

  2. Yes, xla represents torch_xla.
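For the goal in item 1 (only the edited test_utils.py from the fork, everything else from the installed torch_xla), one alternative to importing through the cloned package is to load that single file by path with importlib. Importing through the package runs the fork's torch_xla/__init__.py, which fails because version.py is missing from the clone, as the traceback shows. A minimal sketch, assuming the fork's file is torch_xla/test/test_utils.py under /home/hungwon3626/xla and that it only uses absolute imports (which then resolve to the installed torch_xla); the helper and the module name `forked_test_utils` are hypothetical.

```python
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_module_from_file(module_name: str, file_path) -> ModuleType:
    """Import a single .py file as a top-level module, bypassing its parent package.

    The parent package's __init__.py is never executed, so absolute imports
    inside the file resolve against whatever is installed in the environment.
    """
    path = Path(file_path).expanduser().resolve()
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before executing, like a normal import
    try:
        spec.loader.exec_module(module)
    except Exception:
        sys.modules.pop(module_name, None)  # don't leave a half-initialised module behind
        raise
    return module


# Hypothetical usage with the paths mentioned in this thread:
# test_utils = load_module_from_file(
#     "forked_test_utils", "/home/hungwon3626/xla/torch_xla/test/test_utils.py")
```

This avoids a source build, but any relative imports inside the file would fail; in that case building the fork from source is the more robust route.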

ManfeiBai commented 4 months ago

Thanks for the confirmation, @hungwon.

Yes, if you want to use the forked torch_xla with the edited test_utils, you might need to follow this link (from the "Clone the PyTorch repo as per instructions" step) to build torch and torch_xla from source, and you might also need to rebuild torchdata from source.

If you build torch and torch_xla from source, you can skip the commands to install nightly torch and torch_xla mentioned in this comment.

If you already have torch and torch_xla installed with the commands above, please uninstall them before building from source; having multiple torch/torch_xla installations in the same environment might cause unexpected issues.

Glad to know the error doesn't occur in the notebook.

Thanks for confirming the path; it looks good to me. Please let me know if importing torch_xla fails under DeepLearning/Test/.

Thanks. Since xla represents torch_xla, I would recommend trying:

import torch_xla as xla
import torch_xla.test.test_utils as test_utils  # import statements need the real package name; the xla alias can't be used here
JackCaoG commented 4 months ago

@hungwon the issue is that you run Python inside the xla repo, so it tries to use the torch_xla you cloned instead of the system one. You can likely fix that by running cd ~/ and then python.
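A quick way to confirm which copy Python will pick up, and therefore whether the clone is shadowing the installed torch_xla, is importlib.util.find_spec, which reports the file a top-level import would load without executing it. Because the current directory (in an interactive session or notebook) or the script's directory comes first on sys.path, the answer inside the repo can differ from the one in ~/. A minimal sketch; `import_origin` is a hypothetical helper name.

```python
import importlib.util
import sys


def import_origin(name: str):
    """Return the file a top-level `import name` would load, or None if not found.

    For a module that is already imported, find_spec returns the spec of the
    copy in sys.modules, i.e. the one actually in use.
    """
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    # Namespace packages (directories without __init__.py) have no single origin file
    return spec.origin if spec.has_location else None


if __name__ == "__main__":
    print("sys.path[0] =", repr(sys.path[0]))
    print("torch_xla  ->", import_origin("torch_xla"))
```

If the printed path points into the cloned repo rather than site-packages, the clone is shadowing the installed package.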

hungwon commented 4 months ago

@ManfeiBai @JackCaoG Thank you!