pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Trying to train StyleGAN2 on TPU #3808

Open harshvardhan96 opened 2 years ago

harshvardhan96 commented 2 years ago

❓ Questions and Help

When I run the original code on a GPU, the training steps advance quickly. When I run the same code on a TPU (after making the necessary XLA changes), it runs much slower and each step takes a long time. Any idea how I should debug this?

Here is the link to the training file (stylex_train.py): https://github.com/NoahVl/Explaining-In-Style-Reproducibility-Study/blob/main/stylex/stylex_train.py

JackCaoG commented 2 years ago

Can you follow https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md, enable the Auto-Metrics Analysis and dump the metrics?

harshvardhan96 commented 2 years ago

@JackCaoG Thank you for your response. Here are the metrics for 10 training steps (batch_size = 16, number of workers = 4). Could you please help me interpret them?

Metric: CompileTime
  TotalSamples: 16
  Accumulator: 01m08s608ms146.824us
  ValueRate: 134ms880.491us / second
  Rate: 0.0316839 / second
  Percentiles: 1%=001ms430.094us; 5%=001ms430.094us; 10%=010ms499.381us; 20%=013ms931.663us; 50%=020ms414.246us; 80%=04s525ms614.602us; 90%=23s988ms734.253us; 95%=35s706ms772.082us; 99%=35s706ms772.082us
Metric: DeviceLockWait
  TotalSamples: 310
  Accumulator: 001ms030.907us
  ValueRate: 001.124us / second
  Rate: 0.338129 / second
  Percentiles: 1%=001.000us; 5%=001.426us; 10%=001.543us; 20%=001.710us; 50%=002.274us; 80%=003.678us; 90%=004.939us; 95%=006.750us; 99%=009.127us
Metric: ExecuteTime
  TotalSamples: 300
  Accumulator: 05s862ms799.197us
  ValueRate: 005ms303.050us / second
  Rate: 0.327228 / second
  Percentiles: 1%=001ms397.965us; 5%=002ms591.220us; 10%=002ms793.381us; 20%=002ms010.058us; 50%=003ms353.577us; 80%=017ms026.123us; 90%=063ms994.261us; 95%=109ms683.954us; 99%=111ms082.143us
Metric: InboundData
  TotalSamples: 310
  Accumulator: 1.17GB
  ValueRate: 1.31MB / second
  Rate: 0.338144 / second
  Percentiles: 1%=1.00B; 5%=1.00B; 10%=4.00B; 20%=4.00B; 50%=4.00B; 80%=12.00MB; 90%=12.00MB; 95%=12.00MB; 99%=12.00MB
Metric: InputOutputAliasCount
  TotalSamples: 8
  Accumulator: 8.00
  ValueRate: 0.08 / second
  Rate: 0.0825635 / second
  Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=1.00; 50%=1.00; 80%=1.00; 90%=1.00; 95%=1.00; 99%=1.00
Metric: IrValueTensorToXlaData
  TotalSamples: 943
  Accumulator: 04s018ms753.258us
  ValueRate: 004ms386.564us / second
  Rate: 1.02956 / second
  Percentiles: 1%=461.894us; 5%=512.181us; 10%=539.278us; 20%=587.233us; 50%=001ms072.181us; 80%=008ms904.688us; 90%=010ms193.602us; 95%=015ms722.595us; 99%=034ms287.416us
Metric: OutboundData
  TotalSamples: 1091
  Accumulator: 3.50GB
  ValueRate: 3.91MB / second
  Rate: 1.11905 / second
  Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=1.00KB; 50%=4.00KB; 80%=12.00MB; 90%=12.00MB; 95%=12.00MB; 99%=12.00MB
Metric: ReleaseDataHandlesTime
  TotalSamples: 665
  Accumulator: 02s368ms080.347us
  ValueRate: 003ms569.104us / second
  Rate: 0.721451 / second
  Percentiles: 1%=233.975us; 5%=268.115us; 10%=288.658us; 20%=358.779us; 50%=579.768us; 80%=001ms256.718us; 90%=002ms066.746us; 95%=003ms995.167us; 99%=183ms072.085us
Metric: TensorToData
  TotalSamples: 1091
  Accumulator: 04s442ms078.170us
  ValueRate: 005ms685.072us / second
  Rate: 1.11904 / second
  Percentiles: 1%=467.975us; 5%=519.412us; 10%=551.158us; 20%=598.741us; 50%=001ms342.054us; 80%=008ms780.810us; 90%=010ms913.380us; 95%=015ms716.337us; 99%=034ms131.452us
Metric: TensorsGraphSize
  TotalSamples: 300
  Accumulator: 376260.00
  ValueRate: 410.40 / second
  Rate: 0.327222 / second
  Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=3.00; 50%=5.00; 80%=481.00; 90%=8315.00; 95%=9677.00; 99%=9677.00
Metric: TransferFromServerTime
  TotalSamples: 310
  Accumulator: 03s823ms795.170us
  ValueRate: 003ms079.070us / second
  Rate: 0.338144 / second
  Percentiles: 1%=790.766us; 5%=933.418us; 10%=001ms001.707us; 20%=001ms126.042us; 50%=002ms538.247us; 80%=016ms001.271us; 90%=023ms243.230us; 95%=052ms771.219us; 99%=061ms653.194us
Metric: TransferToServerTime
  TotalSamples: 1091
  Accumulator: 04s429ms743.176us
  ValueRate: 005ms671.000us / second
  Rate: 1.11904 / second
  Percentiles: 1%=462.409us; 5%=513.794us; 10%=545.228us; 20%=592.951us; 50%=001ms332.643us; 80%=008ms753.307us; 90%=010ms886.125us; 95%=015ms685.096us; 99%=034ms104.180us
Metric: TransferToServerTransformTime
  TotalSamples: 1091
  Accumulator: 02s548ms966.267us
  ValueRate: 002ms670.807us / second
  Rate: 1.11905 / second
  Percentiles: 1%=055.354us; 5%=069.073us; 10%=077.654us; 20%=099.746us; 50%=309.758us; 80%=003ms861.072us; 90%=004ms815.327us; 95%=005ms879.032us; 99%=018ms922.266us
Counter: CachedCompile
  Value: 284
Counter: CreateCompileHandles
  Value: 15
Counter: CreateDataHandles
  Value: 1391
Counter: CreateXlaTensor
  Value: 340854
Counter: DestroyDataHandles
  Value: 1186
Counter: DestroyXlaTensor
  Value: 340832
Counter: DeviceDataCacheMiss
  Value: 148
Counter: ReleaseDataHandles
  Value: 1186
Counter: UncachedCompile
  Value: 16
Counter: XRTAllocateFromTensor_Empty
  Value: 46
Counter: XrtCompile_Empty
  Value: 1152
Counter: XrtExecuteChained_Empty
  Value: 1152
Counter: XrtExecute_Empty
  Value: 1152
Counter: XrtMemoryInfo_Empty
  Value: 1152
Counter: XrtRead_Empty
  Value: 1152
Counter: XrtReleaseAllocationHandle_Empty
  Value: 1152
Counter: XrtReleaseCompileHandle_Empty
  Value: 1152
Counter: XrtSessionCount
  Value: 10
Counter: XrtSubTuple_Empty
  Value: 1152
Counter: aten::_local_scalar_dense
  Value: 150
Counter: xla::_copy_from
  Value: 10546
Counter: xla::_to_cpu
  Value: 150
Counter: xla::add
  Value: 31443
Counter: xla::addcmul
  Value: 8800
Counter: xla::cat
  Value: 80
Counter: xla::convolution_backward_overrideable
  Value: 4740
Counter: xla::convolution_overrideable
  Value: 9180
Counter: xla::div
  Value: 27920
Counter: xla::empty
  Value: 771
Counter: xla::exp
  Value: 40
Counter: xla::expand
  Value: 22360
Counter: xla::fill_
  Value: 40
Counter: xla::isnan
  Value: 20
Counter: xla::kl_div
  Value: 20
Counter: xla::kl_div_backward
  Value: 20
Counter: xla::l1_loss
  Value: 20
Counter: xla::l1_loss_backward
  Value: 40
Counter: xla::max
  Value: 40
Counter: xla::max_pool2d
  Value: 160
Counter: xla::mean
  Value: 26600
Counter: xla::min
  Value: 40
Counter: xla::mul
  Value: 18940
Counter: xla::native_batch_norm
  Value: 8800
Counter: xla::native_batch_norm_backward
  Value: 4400
Counter: xla::neg
  Value: 4820
Counter: xla::pow
  Value: 600
Counter: xla::relu_
  Value: 8680
Counter: xla::scatter
  Value: 80
Counter: xla::slice
  Value: 200
Counter: xla::sqrt
  Value: 200
Counter: xla::std
  Value: 8800
Counter: xla::sub
  Value: 9100
Counter: xla::sum
  Value: 13700
Counter: xla::threshold_backward
  Value: 4440
Counter: xla::unsqueeze
  Value: 160
Counter: xla::upsample_bilinear2d
  Value: 160
Counter: xla::upsample_bilinear2d_backward
  Value: 80
Counter: xla::view
  Value: 88240
Counter: xla::zero_
  Value: 80
Metric: XrtAllocateFromTensor
  TotalSamples: 1091
  Accumulator: 02s810ms584.524us
  Mean: 002ms742.340us
  StdDev: 003ms222.445us
  Rate: 1.11904 / second
  Percentiles: 25%=169.691us; 50%=496.086us; 80%=003ms118.087us; 90%=004ms928.865us; 95%=005ms379.404us; 99%=017ms145.919us
Metric: XrtCompile
  TotalSamples: 15
  Accumulator: 01m07s238ms245.384us
  Mean: 04s483ms549.692us
  StdDev: 10s891ms830.111us
  Rate: 0.0297036 / second
  Percentiles: 25%=012ms025.080us; 50%=014ms036.743us; 80%=06s106ms862.083us; 90%=23s844ms543.922us; 95%=35s603ms637.580us; 99%=35s603ms637.580us
Metric: XrtExecute
  TotalSamples: 300
  Accumulator: 04s374ms074.808us
  Mean: 015ms580.249us
  StdDev: 030ms453.819us
  Rate: 0.327228 / second
  Percentiles: 25%=001ms208.633us; 50%=002ms741.932us; 80%=015ms533.545us; 90%=062ms022.147us; 95%=108ms552.622us; 99%=110ms794.591us
Metric: XrtReadLiteral
  TotalSamples: 310
  Accumulator: 02s774ms194.743us
  Mean: 006ms723.209us
  StdDev: 012ms607.363us
  Rate: 0.338142 / second
  Percentiles: 25%=499.612us; 50%=667.681us; 80%=008ms606.129us; 90%=013ms502.567us; 95%=044ms764.384us; 99%=053ms962.497us
Metric: XrtReleaseAllocation
  TotalSamples: 665
  Accumulator: 037ms674.542us
  Mean: 055.150us
  StdDev: 127.777us
  Rate: 0.721372 / second
  Percentiles: 25%=012.794us; 50%=023.366us; 80%=053.257us; 90%=099.703us; 95%=195.669us; 99%=781.286us

pt-xla-profiler: ================================================================================
pt-xla-profiler: Unlowered Op usage summary (more of these ops, lower performance)
pt-xla-profiler: Note: _local_scalar_dense typically indicates CPU context access
pt-xla-profiler: --------------------------------------------------------------------------------
pt-xla-profiler: FRAME (count=150):
pt-xla-profiler: Unlowered Op: "xla_cpu_fallback"
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1560)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1549)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1550)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1552)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1553)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1554)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1604)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: FRAME (count=10):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1439)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:   retry_call (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:101)
pt-xla-profiler:   run_training (cli.py:78)
pt-xla-profiler:   train_from_folder (cli.py:253)
pt-xla-profiler:   _CallAndUpdateTrace (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:681)
pt-xla-profiler:   _Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:466)
pt-xla-profiler:   Fire (/home/harsh/.local/lib/python3.8/site-packages/fire/core.py:141)
pt-xla-profiler:   main (cli.py:263)
pt-xla-profiler:   <module> (cli.py:269)
pt-xla-profiler: 
pt-xla-profiler: 
pt-xla-profiler: ================================================================================

JackCaoG commented 2 years ago

Hmm, what happened in

train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1439)

My guess is that you are trying to print an XLA tensor or something similar. Accessing an XLA tensor before mark_step causes additional compilation and execution.
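To see which source lines these early tensor reads come from, the profiler output above can be tallied per source line. This is a minimal sketch in plain Python, assuming the pt-xla-profiler text has been captured into a string. `tally_top_frames` is a hypothetical helper, not part of torch_xla, and the sample is a trimmed excerpt of the log in this thread.

```python
import os
import re
from collections import Counter

PREFIX = "pt-xla-profiler:"
FRAME_RE = re.compile(r"\s*FRAME \(count=(\d+)\):")
STACK_RE = re.compile(r"\s*(\S+) \((.+):(\d+)\)\s*$")


def tally_top_frames(log_text):
    """Sum FRAME counts by the first listed stack entry of each block."""
    tally = Counter()
    pending = None  # count of the FRAME block whose first stack line is still awaited
    for line in log_text.splitlines():
        if not line.startswith(PREFIX):
            continue
        rest = line[len(PREFIX):]
        frame = FRAME_RE.match(rest)
        if frame:
            pending = int(frame.group(1))
            continue
        stack = STACK_RE.match(rest)
        if stack and pending is not None:
            _func, path, lineno = stack.groups()
            tally[(os.path.basename(path), int(lineno))] += pending
            pending = None  # skip the outer callers listed in the same block
    return tally


# Trimmed excerpt of the profiler output posted above.
SAMPLE = """\
pt-xla-profiler: FRAME (count=150):
pt-xla-profiler: Unlowered Op: "xla_cpu_fallback"
pt-xla-profiler:
pt-xla-profiler: FRAME (count=20):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1560)
pt-xla-profiler:   __retry_internal (/home/harsh/.local/lib/python3.8/site-packages/retry/api.py:33)
pt-xla-profiler:
pt-xla-profiler: FRAME (count=10):
pt-xla-profiler:   train (/home/harsh/thesis-tpu/thesis_1_cuda/stylex/stylex_train.py:1439)
"""

tally = tally_top_frames(SAMPLE)
for (filename, lineno), count in tally.most_common():
    print(f"{filename}:{lineno}  count={count}")
```

In the full log, the per-line frames add up to 150 (seven with count=20 and one with count=10), which matches the aten::_local_scalar_dense counter. Those lines in stylex_train.py are therefore the first places to look for `.item()` calls, prints, or Python-side conditions on tensor values.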

The major issue here is that

Metric: CompileTime
  TotalSamples: 16
  Accumulator: 01m08s608ms146.824us
  ValueRate: 134ms880.491us / second
  Rate: 0.0316839 / second
  Percentiles: 1%=001ms430.094us; 5%=001ms430.094us; 10%=010ms499.381us; 20%=013ms931.663us; 50%=020ms414.246us; 80%=04s525ms614.602us; 90%=23s988ms734.253us; 95%=35s706ms772.082us; 99%=35s706ms772.082us
Metric: ExecuteTime
  TotalSamples: 300
  Accumulator: 05s862ms799.197us
  ValueRate: 005ms303.050us / second
  Rate: 0.327228 / second
  Percentiles: 1%=001ms397.965us; 5%=002ms591.220us; 10%=002ms793.381us; 20%=002ms010.058us; 50%=003ms353.577us; 80%=017ms026.123us; 90%=063ms994.261us; 95%=109ms683.954us; 99%=111ms082.143us

the training time is dominated by CompileTime (about 68.6 s of compilation in total versus about 5.9 s of execution). Given that you said you trained for 10 steps and it compiled 16 times, it is likely that you are recompiling on every step. You might want to check out https://ultrons.medium.com/understanding-the-performance-pytorch-on-cloud-tpus-6b4686905fe4 to see how to debug this.
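The comparison between the two accumulators can be reproduced from the pasted report. This is a rough sketch in plain Python, assuming the duration format seen in this thread (for example `01m08s608ms146.824us`). `parse_duration` and `accumulators` are hypothetical helper names, not torch_xla functions, and only time-valued metrics are handled.

```python
import re

# Seconds per unit, for the duration tokens that appear in the report.
_UNIT_SECONDS = {"h": 3600.0, "m": 60.0, "s": 1.0, "ms": 1e-3, "us": 1e-6}
# "ms" and "us" are listed first so they are not mistaken for "m" or "s".
_TOKEN = re.compile(r"(\d+(?:\.\d+)?)(ms|us|h|m|s)")


def parse_duration(text):
    """Convert a string such as '05s862ms799.197us' to seconds."""
    return sum(float(value) * _UNIT_SECONDS[unit]
               for value, unit in _TOKEN.findall(text))


def accumulators(report):
    """Map each metric name to its Accumulator value in seconds."""
    result = {}
    current = None
    for line in report.splitlines():
        line = line.strip()
        if line.startswith("Metric:"):
            current = line.split(":", 1)[1].strip()
        elif line.startswith("Accumulator:") and current is not None:
            result[current] = parse_duration(line.split(":", 1)[1])
    return result


# The two metrics quoted above, trimmed to the relevant lines.
REPORT = """\
Metric: CompileTime
  TotalSamples: 16
  Accumulator: 01m08s608ms146.824us
Metric: ExecuteTime
  TotalSamples: 300
  Accumulator: 05s862ms799.197us
"""

acc = accumulators(REPORT)
compile_s = acc["CompileTime"]
execute_s = acc["ExecuteTime"]
print(f"compile: {compile_s:.1f} s, execute: {execute_s:.1f} s, "
      f"ratio: {compile_s / execute_s:.1f}x")
```

If the ratio stays this high after the first few steps, the graph is probably changing from step to step, which is consistent with 16 compilations over 10 steps.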