pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

training freezes at certain batch number #2304

Closed joshclancy closed 4 years ago

joshclancy commented 4 years ago

❓ Questions and Help

When reaching a certain number of batches, my training freezes and no error report is created... it simply freezes. I have actually had this issue before, and it was resolved here: https://github.com/pytorch/xla/issues/2020

Now, however, I am training a GAN-based model with multiple data loaders, multiple discriminators, and multiple optimizers. I have tried a few fixes that came to mind, but so far nothing has worked. Is there something XLA-specific that must be done when using multiple optimizers? I looked through the docs but did not see anything.

The exact number of batches before freezing is weird. At 15 images per batch, it always stops at 99 batches (1485 images). At 30 images per batch, it always stops at 89 batches (2670 images). As I said... weird.

My training loop:

```python
for image_batch, labels in data_loader:
    randShape_batch, labels = next(iter(randShape_loader))

    top_disc_optimizer.zero_grad()
    btm_disc_optimizer.zero_grad()
    optimizer.zero_grad()

    randShape_batch = randShape_batch.to(device)
    small_randShape_batch = makeSmall_4up(randShape_batch)

    image_batch = image_batch.to(device)
    labels = labels.to(device).float()

    outputDict = model(outputDict, image_batch)

    # 1st discriminator
    real_small_pred = top_disc(outputDict, small_randShape_batch)
    fake_small_pred = top_disc(outputDict, outputDict["top_gate"].detach())

    top_disc_loss = abs_disc_loss(run, fake_small_pred, real_small_pred, "top_discriminator_loss")
    top_disc_loss.backward()
    xm.optimizer_step(top_disc_optimizer, barrier=True)
    xm.mark_step()

    # 2nd discriminator
    real_pred = btm_disc(outputDict, randShape_batch)
    fake_pred = btm_disc(outputDict, outputDict["btm_gate"].detach())

    disc_lossy = abs_disc_loss(run, fake_pred, real_pred, "btm_discriminator_loss")
    disc_lossy.backward()
    xm.optimizer_step(btm_disc_optimizer, barrier=True)
    xm.mark_step()

    # Train the generator/model
    outputDict["fake_pred"] = btm_disc(outputDict, outputDict["gate"])
    outputDict["fake_small_pred"] = top_disc(outputDict, outputDict["top_gate"])

    # Reconstruction error
    lossDict = loss_fn(image_batch, labels, run, outputDict, lossDict)
    lossy = lossDict["loss"]
    lossy.backward()

    xm.optimizer_step(optimizer, barrier=True)
    xm.mark_step()
```

I also tried using the ParallelLoader method (which I have not used before), but it gets stuck even earlier, and less predictably: I have seen it stuck at batch 4 and at batch 7.

Here is my training loop for that method:

```python
for image_batch, labels in para_loader.per_device_loader(dev):
    randShape_batch, labels = next(iter(RS_para_loader.per_device_loader(dev)))

    top_disc_optimizer.zero_grad()
    btm_disc_optimizer.zero_grad()
    optimizer.zero_grad()

    small_randShape_batch = makeSmall_4up(randShape_batch)
    outputDict["small_randShape_Batch"] = small_randShape_batch

    outputDict = model(outputDict, image_batch)

    # 1st discriminator
    real_small_pred = top_disc(outputDict, small_randShape_batch)
    fake_small_pred = top_disc(outputDict, outputDict["top_gate"].detach())

    top_disc_loss = abs_disc_loss(run, fake_small_pred, real_small_pred, "top_discriminator_loss")
    top_disc_loss.backward()
    xm.optimizer_step(top_disc_optimizer)
    xm.mark_step()

    # 2nd discriminator
    real_pred = btm_disc(outputDict, randShape_batch)
    fake_pred = btm_disc(outputDict, outputDict["btm_gate"].detach())

    disc_loss = abs_disc_loss(run, fake_pred, real_pred, "btm_discriminator_loss")
    disc_loss.backward()
    xm.optimizer_step(btm_disc_optimizer)
    xm.mark_step()

    # Train the generator/model
    outputDict["fake_pred"] = btm_disc(outputDict, outputDict["gate"])
    outputDict["fake_small_pred"] = top_disc(outputDict, outputDict["top_gate"])

    # Reconstruction error
    lossDict = loss_fn(image_batch, labels, run, outputDict, lossDict)
    lossy = lossDict["loss"]
    lossy.backward()

    xm.optimizer_step(optimizer)
    xm.mark_step()
```

Any ideas or help would be much appreciated! Thanks!

dlibenzi commented 4 years ago

It is impossible for us to even try to debug based on a few lines dropped in an edit box. From our issue report page:


It is really important for the team to have a quick repro, which requires no setup work.

The quicker the repro is to run, the higher the chance the bug will be addressed soon.

The best way to create quick repros is to create a Colab based on the following template:

https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#using-debug_runpy-to-collect-debug-information

Things to avoid in repros are datasets that require setting up keys or other login information to download, like Kaggle downloads for example.

Another example is a Colab that mounts the user's Google Drive storage.

Using a fake data generator can be a solution when the dataset cannot be easily downloaded without setting up credentials:

https://github.com/pytorch/xla/blob/784b4d4f21751a54be0029a95f47d3896561c2a9/test/test_train_mp_mnist.py#L65
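
For reference, the fake-data pattern used in that test looks roughly like this (the shapes, batch size, and sample count below are placeholders, not values from this issue):

```python
# Sketch of a fake data generator, modeled on test_train_mp_mnist.py;
# every batch returns the same zero tensors, which is enough for a repro.
import torch
import torch_xla.utils.utils as xu

batch_size = 15                                           # placeholder value
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, 3, 64, 64),             # fake "images"
          torch.zeros(batch_size, dtype=torch.int64)),    # fake "labels"
    sample_count=1000)                                     # placeholder batch count
```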

joshclancy commented 4 years ago

Fair enough, I'll work on a Colab example. I tried to use debug_run.py but so far have been unable to (I am still working on it). In the meantime, here are the metrics from metrics_report() right before the training freezes.
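
(The report below comes from the torch_xla debug metrics API; a minimal sketch of the call, in case it is useful to others:)

```python
# Print the accumulated torch_xla metrics (compile times, transfers, op counters).
import torch_xla.debug.metrics as met

print(met.metrics_report())
```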

After 23 epochs, training froze:

```
Metric: CompileTime
  TotalSamples: 66
  Accumulator: 03m58s114ms439.094us
  ValueRate: 869ms729.629us / second
  Rate: 0.321906 / second
  Percentiles: 1%=007ms443.578us; 5%=008ms957.204us; 10%=008ms259.868us; 20%=009ms546.083us; 50%=012ms797.975us; 80%=04s800ms812.593us; 90%=06s128ms561.840us; 95%=12s095ms371.171us; 99%=35s707ms669.786us
Metric: DeviceLockWait
  TotalSamples: 654
  Accumulator: 11s315ms076.945us
  ValueRate: 052ms388.075us / second
  Rate: 3.02798 / second
  Percentiles: 1%=001.362us; 5%=001.584us; 10%=001.750us; 20%=001.987us; 50%=002.609us; 80%=004.401us; 90%=006.173us; 95%=220ms030.495us; 99%=270ms253.957us
Metric: ExecuteTime
  TotalSamples: 558
  Accumulator: 15s053ms958.183us
  ValueRate: 070ms742.922us / second
  Rate: 2.58531 / second
  Percentiles: 1%=002ms060.530us; 5%=002ms287.053us; 10%=002ms400.180us; 20%=003ms522.336us; 50%=003ms829.535us; 80%=008ms097.983us; 90%=053ms862.229us; 95%=232ms119.664us; 99%=281ms117.191us
Metric: InboundData
  TotalSamples: 534
  Accumulator: 647.80MB
  ValueRate: 3.01MB / second
  Rate: 2.4779 / second
  Percentiles: 1%=4.00B; 5%=4.00B; 10%=8.00B; 20%=8.00B; 50%=512.00B; 80%=128.00KB; 90%=5.74MB; 95%=9.00MB; 99%=17.23MB
Metric: InputOutputAliasCount
  TotalSamples: 15
  Accumulator: 4203.00
  ValueRate: 27.39 / second
  Rate: 0.09776 / second
  Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=1.00; 50%=85.00; 80%=740.00; 90%=740.00; 95%=741.00; 99%=741.00
Metric: IrValueTensorToXlaData
  TotalSamples: 810
  Accumulator: 02s253ms696.841us
  ValueRate: 125ms019.471us / second
  Rate: 44.9531 / second
  Percentiles: 1%=001ms075.930us; 5%=001ms195.509us; 10%=001ms297.725us; 20%=001ms447.743us; 50%=002ms761.716us; 80%=002ms175.752us; 90%=005ms361.781us; 95%=008ms213.463us; 99%=026ms613.204us
Metric: OutboundData
  TotalSamples: 961
  Accumulator: 1016.02MB
  ValueRate: 4.69MB / second
  Rate: 4.43227 / second
  Percentiles: 1%=4.00B; 5%=4.00B; 10%=8.00B; 20%=64.00B; 50%=256.00B; 80%=2.00KB; 90%=36.75KB; 95%=576.00KB; 99%=34.45MB
Metric: ReleaseDataHandlesTime
  TotalSamples: 2138
  Accumulator: 12s700ms543.337us
  ValueRate: 565ms606.327us / second
  Rate: 99.2231 / second
  Percentiles: 1%=480.158us; 5%=567.733us; 10%=632.640us; 20%=759.282us; 50%=001ms076.423us; 80%=002ms584.092us; 90%=002ms359.305us; 95%=003ms304.227us; 99%=217ms219.906us
Metric: TensorsGraphSize
  TotalSamples: 559
  Accumulator: 258873.00
  ValueRate: 1199.02 / second
  Rate: 2.58912 / second
  Percentiles: 1%=3.00; 5%=4.00; 10%=4.00; 20%=5.00; 50%=5.00; 80%=150.00; 90%=1063.00; 95%=2352.00; 99%=6597.00
Metric: TransferFromServerTime
  TotalSamples: 534
  Accumulator: 05s626ms746.076us
  ValueRate: 021ms464.646us / second
  Rate: 2.4779 / second
  Percentiles: 1%=001ms269.181us; 5%=001ms468.228us; 10%=002ms554.540us; 20%=002ms653.990us; 50%=002ms954.155us; 80%=003ms090.376us; 90%=023ms075.265us; 95%=045ms265.117us; 99%=106ms006.748us
Metric: TransferToServerTime
  TotalSamples: 961
  Accumulator: 08s153ms084.449us
  ValueRate: 038ms602.076us / second
  Rate: 4.43214 / second
  Percentiles: 1%=001ms072.837us; 5%=001ms216.206us; 10%=001ms324.757us; 20%=001ms474.060us; 50%=002ms845.732us; 80%=003ms921.902us; 90%=007ms447.833us; 95%=014ms676.903us; 99%=229ms497.124us
Metric: TransferToServerTransformTime
  TotalSamples: 961
  Accumulator: 400ms738.984us
  ValueRate: 002ms843.655us / second
  Rate: 4.43227 / second
  Percentiles: 1%=071.494us; 5%=079.005us; 10%=083.841us; 20%=091.412us; 50%=114.801us; 80%=166.282us; 90%=307.898us; 95%=631.182us; 99%=010ms673.265us
Counter: CachedCompile
  Value: 493
Counter: CreateCompileHandles
  Value: 66
Counter: CreateDataHandles
  Value: 38618
Counter: CreateXlaTensor
  Value: 75154
Counter: DestroyDataHandles
  Value: 37030
Counter: DestroyXlaTensor
  Value: 73921
Counter: MarkStep
  Value: 120
Counter: ReleaseDataHandles
  Value: 37030
Counter: UncachedCompile
  Value: 66
Counter: XRTAllocateFromTensor_Empty
  Value: 95
Counter: XrtCompile_Empty
  Value: 1280
Counter: XrtExecuteChained_Empty
  Value: 1280
Counter: XrtExecute_Empty
  Value: 1280
Counter: XrtRead_Empty
  Value: 1280
Counter: XrtReleaseAllocationHandle_Empty
  Value: 1280
Counter: XrtReleaseCompileHandle_Empty
  Value: 1280
Counter: XrtSessionCount
  Value: 12
Counter: XrtSubTuple_Empty
  Value: 1280
Counter: aten::_local_scalar_dense
  Value: 42
Counter: aten::isnan
  Value: 48
Counter: xla::abs
  Value: 72
Counter: xla::add
  Value: 3432
Counter: xla::add_
  Value: 15778
Counter: xla::addcdiv_
  Value: 4944
Counter: xla::addcmul_
  Value: 4944
Counter: xla::addmm
  Value: 384
Counter: xla::as_strided
  Value: 945
Counter: xla::bernoulli_
  Value: 1008
Counter: xla::cat
  Value: 216
Counter: xla::convolution_backward_overrideable
  Value: 2448
Counter: xla::convolution_overrideable
  Value: 2928
Counter: xla::copy_
  Value: 1413
Counter: xla::div
  Value: 4968
Counter: xla::div_
  Value: 1008
Counter: xla::empty
  Value: 4597
Counter: xla::empty_strided
  Value: 945
Counter: xla::eq
  Value: 48
Counter: xla::expand
  Value: 24
Counter: xla::fill_
  Value: 120
Counter: xla::leaky_relu
  Value: 2040
Counter: xla::leaky_relu_backward
  Value: 2016
Counter: xla::max_pool2d_with_indices
  Value: 144
Counter: xla::max_pool2d_with_indices_backward
  Value: 24
Counter: xla::mean
  Value: 48
Counter: xla::mm
  Value: 624
Counter: xla::mse_loss
  Value: 192
Counter: xla::mse_loss_backward
  Value: 192
Counter: xla::mul
  Value: 2688
Counter: xla::mul_
  Value: 9888
Counter: xla::native_batch_norm
  Value: 1992
Counter: xla::native_batch_norm_backward
  Value: 1488
Counter: xla::neg
  Value: 48
Counter: xla::relu_
  Value: 816
Counter: xla::rsub
  Value: 48
Counter: xla::select
  Value: 24
Counter: xla::sigmoid
  Value: 288
Counter: xla::sigmoid_backward
  Value: 264
Counter: xla::slice
  Value: 384
Counter: xla::sqrt
  Value: 4944
Counter: xla::sub
  Value: 24
Counter: xla::sum
  Value: 408
Counter: xla::t
  Value: 1296
Counter: xla::threshold_backward
  Value: 408
Counter: xla::unsqueeze
  Value: 24
Counter: xla::view
  Value: 1128
Counter: xla::zero_
  Value: 5246
Metric: XrtAllocateFromTensor
  TotalSamples: 1097
  Accumulator: 02s859ms787.490us
  Mean: 002ms775.267us
  StdDev: 004ms402.246us
  Rate: 4.72979 / second
  Percentiles: 25%=316.705us; 50%=479.695us; 80%=001ms348.581us; 90%=003ms056.597us; 95%=007ms464.420us; 99%=025ms832.509us
Metric: XrtCompile
  TotalSamples: 66
  Accumulator: 03m58s602ms771.738us
  Mean: 03s691ms935.935us
  StdDev: 07s082ms782.780us
  Rate: 0.321908 / second
  Percentiles: 25%=007ms135.490us; 50%=009ms877.287us; 80%=04s783ms445.453us; 90%=06s120ms656.876us; 95%=12s070ms813.042us; 99%=35s661ms145.688us
Metric: XrtExecute
  TotalSamples: 558
  Accumulator: 14s972ms006.724us
  Mean: 025ms039.439us
  StdDev: 071ms799.556us
  Rate: 2.58531 / second
  Percentiles: 25%=001ms163.266us; 50%=001ms288.046us; 80%=004ms248.256us; 90%=051ms778.122us; 95%=230ms858.872us; 99%=278ms294.446us
Metric: XrtExecutorEvict
  TotalSamples: 0
  Accumulator: nanB
  Mean: nanB
  StdDev: nanB
  Percentiles: 
Metric: XrtReadLiteral
  TotalSamples: 534
  Accumulator: 935ms645.899us
  Mean: 002ms750.273us
  StdDev: 004ms753.592us
  Rate: 2.47806 / second
  Percentiles: 25%=499.495us; 50%=616.926us; 80%=869.396us; 90%=005ms751.801us; 95%=008ms448.582us; 99%=018ms730.832us
Metric: XrtReleaseAllocation
  TotalSamples: 2138
  Accumulator: 496ms322.786us
  Mean: 254.052us
  StdDev: 504.556us
  Rate: 99.2224 / second
  Percentiles: 25%=016.489us; 50%=038.218us; 80%=323.149us; 90%=894.542us; 95%=001ms324.177us; 99%=003ms605.010us
```

joshclancy commented 4 years ago

Whooo! Got it working.

I was working on a Colab version for others to play with, but couldn't get the error to occur. After a whole lot of cursing and attempted debugging, I found this old post: https://github.com/pytorch/xla/issues/1562

Turns out I had the same problem: Weights & Biases (wandb) logging of my model was causing the freeze.
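
For anyone who hits the same thing, a hypothetical sketch of the kind of model logging involved (assuming wandb.watch was used; the exact call is not shown in this thread):

```python
# Hypothetical example only -- the exact logging calls are not shown in this thread.
import wandb

wandb.init(project="my-gan-project")  # placeholder project name

# wandb.watch hooks model parameters/gradients and pulls their values off the
# device; commenting it out (or logging scalars only via wandb.log) is a quick
# way to test whether the logging is what stalls the XLA training loop.
# wandb.watch(model, log="all")

wandb.log({"top_discriminator_loss": 0.0})  # scalar-only logging example
```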

DevJake commented 2 years ago

I'm having the same issue as described here and in https://github.com/pytorch/xla/issues/1178, https://github.com/pytorch/xla/issues/2749, https://github.com/pytorch/xla/issues/1562 and https://github.com/huggingface/accelerate/issues/287. No form of logs or output is generated; the program just hangs. I have verified that it is the call to accelerate.backward(loss) that hangs. The model also hangs if this is exchanged for loss.backward().

I have completely removed any and all wandb content from my code, but the problem still exists. I'm really not sure what to do, or what a possible fix might be. I have also tried switching accelerate over to use the CPU as the device, and the problem persists.

The exact line wherein the issue occurs can be found here. Many thanks in advance to anyone who can offer their help.

JackCaoG commented 2 years ago

> Switching over accelerate to use the CPU as the device

Do you mean XLA:CPU or native CPU?
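
(For readers unfamiliar with the distinction: a native CPU device is plain eager PyTorch, while XLA:CPU still goes through the torch_xla/XLA stack. A minimal illustration, not from this thread:)

```python
import torch
import torch_xla.core.xla_model as xm

native_cpu = torch.device("cpu")  # plain eager PyTorch on the host CPU
xla_dev = xm.xla_device()         # XLA device (TPU, GPU, or XLA:CPU depending on setup)
print(native_cpu, xla_dev)        # e.g. "cpu xla:0"
```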

DevJake commented 2 years ago

I am switching to CPU via accelerate config, so I would assume that is native CPU? However, be it a single CPU, multi-CPU, GPU or TPU, the same issue occurs. My understanding of how accelerate and XLA link together isn't clear. I assume accelerate just interacts with XLA behind the scenes?

Is there perhaps some way to force an error message?

JackCaoG commented 2 years ago

Oh OK, thanks for confirming. If the issue occurs even when you use native CPU (which seems to be the case), I think it is better to open the issue in huggingface, as it is not specific to pytorch/xla.

DevJake commented 2 years ago

Good thinking. I'll open an issue there and link to this one. I might open an issue here as well, since running loss.backward() still produces the same error; circumventing HuggingFace's Accelerator implementation (seemingly) changes nothing. Thanks for your help!

DevJake commented 2 years ago

It appears I have managed to fix this! All I had to do was build Accelerate from source.

For anyone else who encounters this issue or similar, I have included the contents of my requirements.txt file:

```
absl-py==1.1.0
accelerate==0.13.0.dev0
astunparse==1.6.3
attrs==19.3.0
Automat==0.8.0
blinker==1.4
cachetools==5.2.0
certifi==2022.6.15
chardet==3.0.4
charset-normalizer==2.1.0
Click==7.0
cloud-init==22.2
cloud-tpu-client==0.10
colorama==0.4.3
command-not-found==0.3
commonmark==0.9.1
configobj==5.0.6
constantly==15.1.0
cryptography==2.8
Cython==0.29.14
dbus-python==1.2.16
distlib==0.3.4
distro==1.4.0
distro-info===0.23ubuntu1
docker-pycreds==0.4.0
einops==0.4.1
ema-pytorch==0.0.10
entrypoints==0.3
filelock==3.7.1
flatbuffers==2.0
future==0.18.2
gast==0.4.0
gitdb==4.0.9
GitPython==3.1.27
google-api-core==1.31.6
google-api-python-client==1.8.0
google-auth==2.9.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.56.3
grpcio==1.47.0
h5py==3.7.0
httplib2==0.20.4
hyperlink==19.0.0
idna==3.3
importlib-metadata==4.12.0
incremental==16.10.1
intel-openmp==2022.1.0
Jinja2==2.10.1
jsonpatch==1.22
jsonpointer==2.0
jsonschema==3.2.0
keras==2.9.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
keyring==18.0.1
language-selector==0.1
launchpadlib==1.10.13
lazr.restfulclient==0.14.2
lazr.uri==1.0.3
libclang==14.0.1
libtpu-nightly==0.1.dev20220518
Markdown==3.3.7
MarkupSafe==1.1.0
mkl==2022.1.0
mkl-include==2022.1.0
mock==4.0.3
more-itertools==4.2.0
netifaces==0.10.4
numpy==1.23.0
oauth2client==4.1.3
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==21.3
pathtools==0.1.2
pexpect==4.6.0
Pillow==9.2.0
platformdirs==2.5.2
promise==2.3
protobuf==3.20.1
psutil==5.9.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.13.0
PyGObject==3.36.0
PyHamcrest==1.9.0
PyJWT==1.7.1
pymacaroons==0.13.0
PyNaCl==1.3.0
pyOpenSSL==19.0.0
pyparsing==3.0.9
pyrsistent==0.15.5
pyserial==3.4
python-apt==2.0.0+ubuntu0.20.4.7
python-debian===0.1.36ubuntu1
pytz==2022.1
PyYAML==5.4.1
requests==2.28.1
requests-oauthlib==1.3.1
requests-unixsocket==0.2.0
rich==12.5.1
rsa==4.8
SecretStorage==2.3.1
sentry-sdk==1.9.5
service-identity==18.1.0
setproctitle==1.3.2
shortuuid==1.0.9
simplejson==3.16.0
six==1.16.0
smmap==5.0.0
sos==4.3
ssh-import-id==5.10
systemd-python==234
tbb==2021.6.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-estimator==2.9.0
tensorflow-io-gcs-filesystem==0.26.0
termcolor==1.1.0
torch==1.12.1
torchvision==0.13.1
tqdm==4.64.0
Twisted==18.9.0
typing-extensions==4.2.0
ubuntu-advantage-tools==27.8
ufw==0.36
unattended-upgrades==0.1
uritemplate==3.0.1
urllib3==1.26.9
virtualenv==20.15.1
wadllib==1.3.3
wandb==0.13.2
Werkzeug==2.1.2
wrapt==1.14.1
zipp==1.0.0
zope.interface==4.7.1
```

A copy of the updated requirements.txt is available here.

I executed the following commands:

1. `sudo pip uninstall torch_xla`
2. `sudo pip uninstall accelerate`
3. `pip install git+https://github.com/huggingface/accelerate`

Command 3 is critical: it builds Accelerate from its source code rather than installing the prebuilt distribution. I am still none the wiser as to what caused this issue in the first place, and I will probably go ahead and open a bug report over at their repository. The lack of an error message makes this especially frustrating to resolve.

I'm glad this is finally resolved. My thesis is due very soon, and to have such a critical model-breaking issue this close was a big concern.

🤗

JackCaoG commented 2 years ago

@DevJake Nice! Glad that this worked out for you.