pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Kaggle TPU Finetuning Roberta Help #6015

Open gurveervirk opened 1 year ago

gurveervirk commented 1 year ago

❓ Questions and Help

I have pretrained roberta-base on DNA promoter sequences of plants (working on a project). I am currently trying to finetune it on a downstream task of predicting gene expression values, basically a list of 8 values (corresponding to various tissues) from a single promoter sequence.

This wasn't possible on Kaggle's GPU (due to memory restrictions), so I tried to do the same on TPU using pytorch-xla (figured that was the best option). The links to the notebook and the datasets used are as follows:

  1. Main Kaggle Notebook
  2. Dataset containing code and data
  3. Dataset on github (contains old code but has the correct structure)

Version 43 is the one using the pytorch-xla code (as far as I could figure out). The data's format is as follows:

sequence \t labels
dna_promoter_seq_here \t list_of_8_values_here

eg: CTCAAGCTGAGCAGTGGGTTTGCTCTGGAGGGGAAGCTCAACGGTGGCGACAAGGAAGAATCTGCTTGCGAGGCGAGCCCTGACGCCGCTGATAGCGACCAAAGGTGGATTAAACAACCCATTTCATCATTCTTCTTCCTTGTTAGTTATGATTCCCACGCTTGCCTTTCATGAATCATGATCCTATATGTATATTGATATTAATCAGTTCTAGAAAGTTCAACAACATTTGAGCATGTCAAAACCTGATCGTTGCCTGTTCCATGTCAACAGTGGATTATAACACGTGCAAATGTAGCTATTTGTGTGAGAAGACGTGTGATCGACTCTTTTTTTATATAGATAGCATTGAGATCAACTGTTTGTATATATCTTGTCATAACATTTTTACTTCGTAGCAACGTACGAGCGTTCACCTATTTGTATATAAGTTATCATGATATTTATAAGTTACCGTTGCAACGCACGGACACTCACCTAGTATAGTTTATGTATTACAGTACTAGGAGCCCTAGGCTTCCAATAACTAGAAAAAGTCCTGGTCAGTCGAACCAAACCACAATCCGACGTATACATTCTGGTTCCCCCACGCCCCCATCCGTTCGATTCA [54.679647, 60.646678, 54.9113, 78.878474, 21.326259, 27.973276, 17.419968, 40.465529]
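A minimal sketch of parsing one such line (the bracketed-list label format is assumed from the example above; `parse_line` is an illustrative helper, not from the notebook):

```python
import ast

def parse_line(line: str):
    """Split a tab-separated data line into (sequence, labels),
    where labels is the list of 8 per-tissue expression values."""
    sequence, labels_str = line.rstrip("\n").split("\t")
    labels = ast.literal_eval(labels_str)  # "[54.679647, ...]" -> list of floats
    assert len(labels) == 8, "expected one value per tissue"
    return sequence, labels

seq, labels = parse_line("ACGT\t[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]")
```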

There are ~722,000 examples of this kind, ~722 MB in total, divided into ~400 MB train, ~200 MB test and ~100 MB eval. When running the code ("finetune.py"), all goes well until training starts (datasets are loaded, processed, etc.). But the latest run took 3+ hrs to reach the next step, and RAM usage kept increasing. It looked like the TPU run was very slow, and the run then crashed as it ran out of memory. I have tried accelerate and Trainer, but those efforts were in vain.

A few questions:

  1. Is my approach correct?
  2. What changes should I make?
  3. Can I run this code using HuggingFace Trainer (was originally used in the code)? If so, how?
  4. Is the RAM usage normal?
  5. Should it take this long?

If I pass the model as an arg to xmp.spawn, I end up seeing one of "Check failed: data()->tensor_data" or "RuntimeError: Function AddcmulBackward0 returned an invalid gradient at index 1 - expected device xla:1 but got xla:0". Why?

Kindly guide.

JackCaoG commented 1 year ago

Hey, thanks for raising the issue. Can you follow https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#perform-a-auto-metrics-analysis, run with PT_XLA_DEBUG, and maybe print metrics per step? If it is super slow, it is usually because the graph keeps recompiling, and we should figure out why. RAM usage usually should not keep increasing.
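On device, the report text comes from `torch_xla.debug.metrics.metrics_report()`. As a rough sketch of the auto-analysis idea, you can grep that text for compile-related counters (counter names taken from the dumps in this thread); `compile_stats` is an illustrative helper, not a torch_xla API:

```python
import re

def compile_stats(report: str) -> dict:
    """Pull compile-related counters out of a torch_xla metrics report string.
    A growing UncachedCompile count across steps signals recompilation."""
    stats = {}
    for name in ("UncachedCompile", "CachedCompile", "DeviceDataCacheMiss"):
        m = re.search(rf"Counter: {name}\s+Value: (\d+)", report)
        if m:
            stats[name] = int(m.group(1))
    return stats

report = "Counter: CachedCompile Value: 3 Counter: UncachedCompile Value: 6"
print(compile_stats(report))  # → {'UncachedCompile': 6, 'CachedCompile': 3}
```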

gurveervirk commented 1 year ago

Thanks for the reply. Will do it and let you know.

gurveervirk commented 1 year ago

I am currently running the notebook with the flag set and am also printing the metrics at each step.

Could you also have a look at the code (finetune.py)?

gurveervirk commented 1 year ago

The metrics and profile messages for the first step are as follows:

0%| | 0/310 [00:00<?, ?it/s]pt-xla-profiler: TransferFromServerTime too frequent: 3 counts during 2 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. step : 0 loss : tensor(7.9881, device='xla:0', grad_fn=) Metric: DeviceLockWait TotalSamples: 16 Accumulator: 100.382us ValueRate: 017.098us / second Rate: 2.72533 / second Percentiles: 1%=001.321us; 5%=001.321us; 10%=002.111us; 20%=002.410us; 50%=006.252us; 80%=009.958us; 90%=012.740us; 95%=013.030us; 99%=013.030us Metric: InputOutputAliasCount TotalSamples: 1 Accumulator: 109.00 Percentiles: 1%=109.00; 5%=109.00; 10%=109.00; 20%=109.00; 50%=109.00; 80%=109.00; 90%=109.00; 95%=109.00; 99%=109.00 Metric: IrValueTensorToXlaData TotalSamples: 222 Accumulator: 099ms477.346us ValueRate: 069ms111.324us / second Rate: 154.233 / second Percentiles: 1%=030.073us; 5%=034.346us; 10%=036.746us; 20%=043.819us; 50%=069.401us; 80%=585.133us; 90%=002ms029.857us; 95%=002ms114.637us; 99%=003ms119.452us Metric: TensorToData TotalSamples: 250 Accumulator: 115ms708.983us ValueRate: 080ms693.238us / second Rate: 173.686 / second Percentiles: 1%=025.726us; 5%=031.670us; 10%=035.181us; 20%=042.264us; 50%=065.340us; 80%=580.368us; 90%=002ms046.274us; 95%=002ms171.485us; 99%=003ms332.095us Metric: TensorsGraphSize TotalSamples: 8 Accumulator: 1363.00 ValueRate: 232.16 / second Rate: 1.36267 / second Percentiles: 1%=3.00; 5%=3.00; 10%=3.00; 20%=3.00; 50%=9.00; 80%=109.00; 90%=1124.00; 95%=1124.00; 99%=1124.00 Metric: UnwrapXlaData TotalSamples: 2653 Accumulator: 531.504us ValueRate: 034.055us / second Rate: 186.54 / second Percentiles: 1%=000.037us; 5%=000.040us; 10%=000.040us; 20%=000.044us; 50%=000.058us; 80%=000.142us; 90%=000.426us; 95%=000.539us; 99%=000.914us Metric: WrapXlaData TotalSamples: 483 Accumulator: 001ms046.984us ValueRate: 159.758us / second Rate: 73.7005 / second Percentiles: 1%=000.359us; 5%=000.364us; 10%=000.375us; 20%=000.410us; 50%=001.263us; 
80%=002.628us; 90%=003.227us; 95%=003.942us; 99%=010.006us Counter: CachedCompile Value: 3 Counter: CreateXlaTensor Value: 2302 Counter: DestroyLtcTensor Value: 1533 Counter: DestroyXlaTensor Value: 1533 Counter: DeviceDataCacheMiss Value: 22 Counter: RegisterXLAFunctions Value: 1 Counter: UncachedCompile Value: 6 Counter: xla::_copy_from Value: 228 Counter: xla::_propagate_xla_data Value: 98 Counter: xla::_softmax Value: 12 Counter: xla::_softmax_backward_data Value: 6 Counter: xla::_to_copy Value: 229 Counter: xla::_to_cpu Value: 6 Counter: xla::_unsafe_view Value: 96 Counter: xla::add Value: 142 Counter: xla::addcmul Value: 27 Counter: xla::addmm Value: 6 Counter: xla::bernoulli Value: 42 Counter: xla::bmm Value: 48 Counter: xla::clamp Value: 1 Counter: xla::clone Value: 4 Counter: xla::cumsum Value: 2 Counter: xla::detach_copy Value: 321 Counter: xla::div Value: 64 Counter: xla::embedding_dense_backward Value: 3 Counter: xla::empty_strided_symint Value: 7 Counter: xla::empty_symint Value: 281 Counter: xla::eq Value: 6 Counter: xla::expand_copysymint Value: 51 Counter: xla::fill Value: 5 Counter: xla::gelu Value: 12 Counter: xla::gelu_backward Value: 6 Counter: xla::index Value: 6 Counter: xla::indexput Value: 4 Counter: xla::mm Value: 150 Counter: xla::mse_loss Value: 2 Counter: xla::mse_loss_backward Value: 1 Counter: xla::mul Value: 124 Counter: xla::native_batch_norm Value: 26 Counter: xla::native_batch_norm_backward Value: 13 Counter: xla::ne Value: 2 Counter: xla::nonzero Value: 4 Counter: xla::norm Value: 2 Counter: xla::permute_copy Value: 72 Counter: xla::rsub Value: 2 Counter: xla::select_copy Value: 8 Counter: xla::slice_copy Value: 8 Counter: xla::sqrt Value: 1 Counter: xla::sum Value: 69 Counter: xla::t_copy Value: 195 Counter: xla::tanh Value: 4 Counter: xla::tanh_backward Value: 2 Counter: xla::transpose_copy Value: 42 Counter: xla::unsqueeze_copy Value: 9 Counter: xla::view_copysymint Value: 598 Counter: xla::zero Value: 2 Metric: CompileTime 
TotalSamples: 5 Accumulator: 06s572ms100.758us ValueRate: 950ms799.286us / second Rate: 0.852281 / second Percentiles: 1%=024ms472.870us; 5%=024ms472.870us; 10%=024ms472.870us; 20%=025ms067.999us; 50%=087ms213.060us; 80%=05s050ms996.673us; 90%=05s050ms996.673us; 95%=05s050ms996.673us; 99%=05s050ms996.673us Metric: ExecuteTime TotalSamples: 8 Accumulator: 145ms336.276us ValueRate: 024ms352.694us / second Rate: 1.34049 / second Percentiles: 1%=001ms136.278us; 5%=001ms136.278us; 10%=001ms136.278us; 20%=001ms301.364us; 50%=002ms970.877us; 80%=031ms829.933us; 90%=104ms480.334us; 95%=104ms480.334us; 99%=104ms480.334us Metric: InboundData TotalSamples: 6 Accumulator: 64.00KB ValueRate: 11.14KB / second Rate: 1.04473 / second Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=4.00B; 50%=16.00KB; 80%=16.00KB; 90%=16.00KB; 95%=16.00KB; 99%=16.00KB Metric: OutboundData TotalSamples: 250 Accumulator: 368.41MB ValueRate: 255.95MB / second Rate: 173.684 / second Percentiles: 1%=4.00B; 5%=4.00B; 10%=32.00B; 20%=3.00KB; 50%=3.00KB; 80%=2.25MB; 90%=9.00MB; 95%=9.00MB; 99%=9.00MB Metric: TransferFromServerTime TotalSamples: 6 Accumulator: 064ms267.560us ValueRate: 011ms190.348us / second Rate: 1.04473 / second Percentiles: 1%=894.451us; 5%=894.451us; 10%=894.451us; 20%=992.734us; 50%=001ms200.132us; 80%=001ms233.308us; 90%=059ms920.629us; 95%=059ms920.629us; 99%=059ms920.629us Metric: TransferToServerTime TotalSamples: 250 Accumulator: 112ms027.827us ValueRate: 078ms829.809us / second Rate: 173.684 / second Percentiles: 1%=019.900us; 5%=027.174us; 10%=029.719us; 20%=034.896us; 50%=055.626us; 80%=571.325us; 90%=002ms035.857us; 95%=002ms141.834us; 99%=003ms308.751us Counter: CreateCompileHandles Value: 5 Counter: CreateDataHandles Value: 518 Counter: MarkStep Value: 2 Counter: aten::_local_scalar_dense Value: 2 Counter: aten::nonzero Value: 4

The run has been stuck on this (first step) for the past 3 hours. Printing the versions:

torch_version = 2.1.0+cu121
torch_xla_version = 2.1.0+libtpu

gurveervirk commented 1 year ago

This is the full relevant output, including profile data and metrics. It starts with the first-step metrics already posted above, and then continues:

pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 3 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.

0%| | 1/310 [3:15:00<1004:17:26, 11700.47s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 4 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 3 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.

0%| | 1/310 [3:15:01<1004:23:27, 11701.64s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 4 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 5 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. 0%| | 1/310 [3:15:01<1004:24:36, 11701.87s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 6 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 3 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. next iter...

0%| | 1/310 [3:15:02<1004:29:37, 11702.84s/it]pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 4 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 3 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.

0%| | 1/310 [3:15:03<1004:32:35, 11703.42s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 4 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 5 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. 0%| | 1/310 [3:15:03<1004:33:17, 11703.55s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 6 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 5 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. 0%| | 1/310 [3:15:03<1004:34:58, 11703.88s/it]pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 6 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 5 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. 0%| | 1/310 [3:15:06<1004:49:14, 11706.65s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 6 steps pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests. 
step : 1 loss : tensor(8.1711, device='xla:0', grad_fn=) Metric: DeviceLockWait TotalSamples: 1292 Accumulator: 044ms688.403us ValueRate: 004.419us / second Rate: 0.108815 / second Percentiles: 1%=001.582us; 5%=002.083us; 10%=002.579us; 20%=003.134us; 50%=006.712us; 80%=013.148us; 90%=015.776us; 95%=017.041us; 99%=028.520us Metric: InputOutputAliasCount TotalSamples: 3 Accumulator: 323.00 ValueRate: 0.03 / second Rate: 0.000257771 / second Percentiles: 1%=107.00; 5%=107.00; 10%=107.00; 20%=107.00; 50%=107.00; 80%=109.00; 90%=109.00; 95%=109.00; 99%=109.00 Metric: IrValueTensorToXlaData TotalSamples: 226 Accumulator: 101ms718.124us ValueRate: 008.603us / second Rate: 0.0193045 / second Percentiles: 1%=030.073us; 5%=034.346us; 10%=036.746us; 20%=044.197us; 50%=071.611us; 80%=582.246us; 90%=002ms029.857us; 95%=002ms114.637us; 99%=003ms119.452us Metric: TensorToData TotalSamples: 262 Accumulator: 117ms530.780us ValueRate: 009.954us / second Rate: 0.0223796 / second Percentiles: 1%=025.726us; 5%=032.703us; 10%=035.306us; 20%=042.515us; 50%=068.723us; 80%=575.596us; 90%=002ms028.180us; 95%=002ms150.922us; 99%=003ms332.095us Metric: TensorsGraphSize TotalSamples: 644 Accumulator: 884205.00 ValueRate: 75.53 / second Rate: 0.0550124 / second Percentiles: 1%=3.00; 5%=9.00; 10%=9.00; 20%=9.00; 50%=2084.00; 80%=2091.00; 90%=2091.00; 95%=2091.00; 99%=2091.00 Metric: UnwrapXlaData TotalSamples: 49281 Accumulator: 024ms269.418us ValueRate: 109.447us / second Rate: 300.588 / second Percentiles: 1%=000.037us; 5%=000.038us; 10%=000.038us; 20%=000.040us; 50%=000.050us; 80%=000.453us; 90%=000.661us; 95%=000.771us; 99%=001.129us Metric: WrapXlaData TotalSamples: 3483 Accumulator: 010ms952.580us ValueRate: 034.018us / second Rate: 14.9119 / second Percentiles: 1%=000.352us; 5%=000.450us; 10%=000.505us; 20%=000.561us; 50%=000.843us; 80%=001.292us; 90%=002.994us; 95%=003.415us; 99%=005.803us Counter: CachedCompile Value: 313 Counter: CreateXlaTensor Value: 9717 Counter: DestroyLtcTensor 
Value: 7814 Counter: DestroyXlaTensor Value: 7814 Counter: DeviceDataCacheMiss Value: 28 Counter: RegisterXLAFunctions Value: 1 Counter: UncachedCompile Value: 331 Counter: xla::_copy_from Value: 238 Counter: xla::_propagate_xla_data Value: 1262 Counter: xla::_softmax Value: 24 Counter: xla::_softmax_backward_data Value: 18 Counter: xla::_to_copy Value: 240 Counter: xla::_to_cpu Value: 638 Counter: xla::_unsafe_view Value: 192 Counter: xla::add Value: 948 Counter: xla::addcmul Value: 267 Counter: xla::addmm Value: 12 Counter: xla::bernoulli Value: 84 Counter: xla::bmm Value: 120 Counter: xla::clamp Value: 215 Counter: xla::clone Value: 8 Counter: xla::cumsum Value: 4 Counter: xla::detach_copy Value: 1002 Counter: xla::div Value: 556 Counter: xla::embedding_dense_backward Value: 9 Counter: xla::empty_strided_symint Value: 439 Counter: xla::empty_symint Value: 769 Counter: xla::eq Value: 431 Counter: xla::expand_copysymint Value: 103 Counter: xla::fill Value: 11 Counter: xla::gelu Value: 24 Counter: xla::gelu_backward Value: 18 Counter: xla::index Value: 12 Counter: xla::indexput Value: 8 Counter: xla::mm Value: 378 Counter: xla::mse_loss Value: 4 Counter: xla::mse_loss_backward Value: 3 Counter: xla::mul Value: 956 Counter: xla::native_batch_norm Value: 52 Counter: xla::native_batch_norm_backward Value: 39 Counter: xla::ne Value: 4 Counter: xla::nonzero Value: 8 Counter: xla::norm Value: 430 Counter: xla::permute_copy Value: 168 Counter: xla::rsub Value: 4 Counter: xla::select_copy Value: 16 Counter: xla::slice_copy Value: 16 Counter: xla::sqrt Value: 215 Counter: xla::sum Value: 203 Counter: xla::t_copy Value: 663 Counter: xla::tanh Value: 8 Counter: xla::tanh_backward Value: 6 Counter: xla::transpose_copy Value: 114 Counter: xla::unsqueeze_copy Value: 19 Counter: xla::view_copysymint Value: 1420 Counter: xla::zero Value: 428 Metric: CompileTime TotalSamples: 331 Accumulator: 05h42m14s504ms505.931us ValueRate: 01s447ms883.923us / second Rate: 0.0282823 / second 
Percentiles: 1%=025ms067.999us; 5%=05s050ms996.673us; 10%=46s890ms058.574us; 20%=48s877ms971.919us; 50%=55s575ms514.354us; 80%=58s507ms508.401us; 90%=60s584ms563.204us; 95%=01m01s644ms909.657us; 99%=01m02s067ms278.600us Metric: ExecuteTime TotalSamples: 644 Accumulator: 23m01s131ms410.269us ValueRate: 118ms980.164us / second Rate: 0.0550123 / second Percentiles: 1%=896.992us; 5%=992.987us; 10%=001ms102.272us; 20%=001ms288.992us; 50%=719ms605.355us; 80%=05s564ms161.242us; 90%=06s098ms615.621us; 95%=07s158ms827.174us; 99%=09s799ms785.496us Metric: InboundData TotalSamples: 640 Accumulator: 129.23KB ValueRate: 11.30B / second Rate: 0.0546717 / second Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=1.00B; 50%=1.00B; 80%=4.00B; 90%=4.00B; 95%=4.00B; 99%=16.00KB Metric: OutboundData TotalSamples: 262 Accumulator: 368.42MB ValueRate: 32.22KB / second Rate: 0.0223796 / second Percentiles: 1%=4.00B; 5%=4.00B; 10%=8.00B; 20%=3.00KB; 50%=3.00KB; 80%=2.25MB; 90%=2.25MB; 95%=9.00MB; 99%=9.00MB Metric: TransferFromServerTime TotalSamples: 640 Accumulator: 23m48s507ms070.396us ValueRate: 117ms818.584us / second Rate: 0.0546717 / second Percentiles: 1%=888.834us; 5%=001ms105.151us; 10%=001ms196.744us; 20%=001ms451.129us; 50%=702ms927.343us; 80%=05s552ms797.989us; 90%=06s162ms230.704us; 95%=07s176ms165.412us; 99%=09s775ms123.856us Metric: TransferToServerTime TotalSamples: 262 Accumulator: 114ms674.374us ValueRate: 009.710us / second Rate: 0.0223796 / second Percentiles: 1%=019.900us; 5%=027.488us; 10%=029.736us; 20%=035.054us; 50%=056.293us; 80%=564.674us; 90%=002ms020.101us; 95%=002ms110.438us; 99%=003ms308.751us Counter: CreateCompileHandles Value: 331 Counter: CreateDataHandles Value: 2883 Counter: MarkStep Value: 6 Counter: aten::_local_scalar_dense Value: 630 Counter: aten::nonzero Value: 8

https://symbolize.stripped_domain/r/?trace=https://symbolize.stripped_domain/r/?trace=https://symbolize.stripped_domain/r/?trace=7ccdf9354da6,7ccdf9354da6,7ccdf9354da6,7ccdf930afcf7ccdf930afcf7ccdf930afcf&map=&map=&map=

SIGTERM received by PID 3970 (TID 3970) on cpu 2 from PID 127; stack trace: SIGTERM received by PID 3968 (TID 3968) on cpu 3 from PID 127; stack trace: SIGTERM received by PID 3969 (TID 3969) on cpu 58 from PID 127; stack trace: PC: @ 0x7ccdf9354da6 (unknown) (unknown) PC: @ 0x7ccdf9354da6 (unknown) (unknown) PC: @ 0x7ccdf9354da6 (unknown) (unknown) @ 0x7cca2b1aa53a 1152 (unknown) @ 0x7cca2b1aa53a 1152 (unknown) @ 0x7cca2b1aa53a 1152 (unknown) @ 0x7ccdf930afd0 (unknown) (unknown) @ 0x7ccdf930afd0 (unknown) (unknown) https://symbolize.stripped_domain/r/?trace=https://symbolize.stripped_domain/r/?trace=7ccdf9354da6,7ccdf9354da6,7cca2b1aa539, @ 0x7ccdf930afd0 (unknown) (unknown) 7cca2b1aa539,7ccdf930afcf7ccdf930afcfhttps://symbolize.stripped_domain/r/?trace=&map=&map=7ccdf9354da6,abbd016d9542b8098892badc0b19ea68:7cca1e000000-7cca2b3becf0abbd016d9542b8098892badc0b19ea68:7cca1e000000-7cca2b3becf07cca2b1aa539,

7ccdf930afcf&map=abbd016d9542b8098892badc0b19ea68:7cca1e000000-7cca2b3becf0 E1205 09:22:09.119284 3969 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM. E1205 09:22:09.119295 3968 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM. E1205 09:22:09.119287 3970 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM. E1205 09:22:09.615902 3970 process_state.cc:783] RAW: Raising signal 15 with default behavior E1205 09:22:09.666330 3968 process_state.cc:783] RAW: Raising signal 15 with default behavior E1205 09:22:09.744279 3969 process_state.cc:783] RAW: Raising signal 15 with default behavior Traceback (most recent call last): File "/kaggle/working/florabert/scripts/1-modeling/finetune.py", line 220, in xmp.spawn(_mp_fn, start_method = 'fork') File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 82, in wrapper return fn(*args, *kwargs) File "/usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn return pjrt.spawn(fn, nprocs, start_method, args) File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 202, in spawn run_multiprocess(spawn_fn, start_method=start_method) File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 82, in wrapper return fn(args, **kwargs) File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 159, in run_multiprocess replica_results = list( File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 160, in itertools.chain.from_iterable( File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists for element in iterable: File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator yield _result_or_cancel(fs.pop()) File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel return fut.result(timeout) File 
"/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result return self.get_result() File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in get_result raise self._exception concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending. WARNING:pt-xla-profiler:================================================================================ WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance) WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access WARNING:pt-xla-profiler:-------------------------------------------------------------------------------- WARNING:pt-xla-profiler:FRAME (count=3456): WARNING:pt-xla-profiler:Unlowered Op: "xla_cpu_fallback" WARNING:pt-xla-profiler: WARNING:pt-xla-profiler: WARNING:pt-xla-profiler:FRAME (count=2296): WARNING:pt-xla-profiler: step (/usr/local/lib/python3.10/site-packages/torch_optimizer/lamb.py:146) WARNING:pt-xla-profiler: wrapper (/usr/local/lib/python3.10/site-packages/torch/optim/optimizer.py:373) WARNING:pt-xla-profiler: wrapper (/usr/local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:68) WARNING:pt-xla-profiler: optimizer_step (/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:941) WARNING:pt-xla-profiler: _mp_fn (/kaggle/working/florabert/scripts/1-modeling/finetune.py:187) WARNING:pt-xla-profiler: call (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:178) WARNING:pt-xla-profiler: _thread_fn (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:68) WARNING:pt-xla-profiler: run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58) WARNING:pt-xla-profiler: _worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83) WARNING:pt-xla-profiler: run (/usr/local/lib/python3.10/threading.py:953) WARNING:pt-xla-profiler: _bootstrap_inner 
WARNING:pt-xla-profiler:   _bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
WARNING:pt-xla-profiler:   _bootstrap (/usr/local/lib/python3.10/threading.py:973)
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:FRAME (count=1128):
WARNING:pt-xla-profiler:   step (/usr/local/lib/python3.10/site-packages/torch_optimizer/lamb.py:156)
WARNING:pt-xla-profiler:   wrapper (/usr/local/lib/python3.10/site-packages/torch/optim/optimizer.py:373)
WARNING:pt-xla-profiler:   wrapper (/usr/local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:68)
WARNING:pt-xla-profiler:   optimizer_step (/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:941)
WARNING:pt-xla-profiler:   _mp_fn (/kaggle/working/florabert/scripts/1-modeling/finetune.py:187)
WARNING:pt-xla-profiler:   __call__ (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:178)
WARNING:pt-xla-profiler:   _thread_fn (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:68)
WARNING:pt-xla-profiler:   run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58)
WARNING:pt-xla-profiler:   _worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83)
WARNING:pt-xla-profiler:   run (/usr/local/lib/python3.10/threading.py:953)
WARNING:pt-xla-profiler:   _bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
WARNING:pt-xla-profiler:   _bootstrap (/usr/local/lib/python3.10/threading.py:973)
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:FRAME (count=16):
WARNING:pt-xla-profiler:   embed (/kaggle/working/florabert/module/florabert/models.py:72)
WARNING:pt-xla-profiler:   forward (/kaggle/working/florabert/module/florabert/models.py:67)
WARNING:pt-xla-profiler:   _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
WARNING:pt-xla-profiler:   _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
WARNING:pt-xla-profiler:   forward (/kaggle/working/florabert/module/florabert/models.py:250)
WARNING:pt-xla-profiler:   _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
WARNING:pt-xla-profiler:   _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
WARNING:pt-xla-profiler:   _mp_fn (/kaggle/working/florabert/scripts/1-modeling/finetune.py:182)
WARNING:pt-xla-profiler:   __call__ (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:178)
WARNING:pt-xla-profiler:   _thread_fn (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:68)
WARNING:pt-xla-profiler:   run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58)
WARNING:pt-xla-profiler:   _worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83)
WARNING:pt-xla-profiler:   run (/usr/local/lib/python3.10/threading.py:953)
WARNING:pt-xla-profiler:   _bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
WARNING:pt-xla-profiler:   _bootstrap (/usr/local/lib/python3.10/threading.py:973)
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:FRAME (count=16):
WARNING:pt-xla-profiler:   embed (/kaggle/working/florabert/module/florabert/models.py:73)
WARNING:pt-xla-profiler:   forward (/kaggle/working/florabert/module/florabert/models.py:67)
WARNING:pt-xla-profiler:   _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
WARNING:pt-xla-profiler:   _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
WARNING:pt-xla-profiler:   forward (/kaggle/working/florabert/module/florabert/models.py:250)
WARNING:pt-xla-profiler:   _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
WARNING:pt-xla-profiler:   _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
WARNING:pt-xla-profiler:   _mp_fn (/kaggle/working/florabert/scripts/1-modeling/finetune.py:182)
WARNING:pt-xla-profiler:   __call__ (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:178)
WARNING:pt-xla-profiler:   _thread_fn (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:68)
WARNING:pt-xla-profiler:   run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58)
WARNING:pt-xla-profiler:   _worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83)
WARNING:pt-xla-profiler:   run (/usr/local/lib/python3.10/threading.py:953)
WARNING:pt-xla-profiler:   _bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
WARNING:pt-xla-profiler:   _bootstrap (/usr/local/lib/python3.10/threading.py:973)
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:================================================================================

JackCaoG commented 1 year ago

There are a couple of metrics in your dump; the first one looks OK:

Metric: CompileTime
TotalSamples: 5
Accumulator: 06s572ms100.758us
ValueRate: 950ms799.286us / second
Rate: 0.852281 / second
Percentiles: 1%=024ms472.870us; 5%=024ms472.870us; 10%=024ms472.870us; 20%=025ms067.999us; 50%=087ms213.060us; 80%=05s050ms996.673us; 90%=05s050ms996.673us; 95%=05s050ms996.673us; 99%=05s050ms996.673us

Metric: ExecuteTime
TotalSamples: 8
Accumulator: 145ms336.276us
ValueRate: 024ms352.694us / second
Rate: 1.34049 / second
Percentiles: 1%=001ms136.278us; 5%=001ms136.278us; 10%=001ms136.278us; 20%=001ms301.364us; 50%=002ms970.877us; 80%=031ms829.933us; 90%=104ms480.334us; 95%=104ms480.334us; 99%=104ms480.334us

So you compile 5 times and execute 8 times. It is normal to compile a couple of times at the beginning of training. However, if I look at the last dump:

Metric: CompileTime
TotalSamples: 331
Accumulator: 05h42m14s504ms505.931us
ValueRate: 01s447ms883.923us / second
Rate: 0.0282823 / second
Percentiles: 1%=025ms067.999us; 5%=05s050ms996.673us; 10%=46s890ms058.574us; 20%=48s877ms971.919us; 50%=55s575ms514.354us; 80%=58s507ms508.401us; 90%=60s584ms563.204us; 95%=01m01s644ms909.657us; 99%=01m02s067ms278.600us
Metric: ExecuteTime
TotalSamples: 644
Accumulator: 23m01s131ms410.269us
ValueRate: 118ms980.164us / second
Rate: 0.0550123 / second
Percentiles: 1%=896.992us; 5%=992.987us; 10%=001ms102.272us; 20%=001ms288.992us; 50%=719ms605.355us; 80%=05s564ms161.242us; 90%=06s098ms615.621us; 95%=07s158ms827.174us; 99%=09s799ms785.496us

it seems like the majority of the time is spent on recompiling. https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/ might give you some insight into how lazy tensors work. Essentially, on every step you are executing a program that's slightly different (maybe the ops are the same but the input shapes keep changing), so the compiler keeps recompiling. You want to dump the IR or HLO for every step and compare them; they should be identical.
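One common source of this kind of recompilation is variable-length batches: if each batch is padded only to the longest sequence it contains, every new length traces a new graph. As a minimal sketch (the helper below is hypothetical, not from the thread), you can pad every batch to one fixed length so XLA always sees identical input shapes:

```python
import torch

def pad_to_fixed(batch_ids, max_len, pad_id=1):
    """Pad (and truncate) every sequence in a batch to the same fixed
    length, so the traced XLA graph has identical input shapes each step."""
    out = torch.full((len(batch_ids), max_len), pad_id, dtype=torch.long)
    for i, seq in enumerate(batch_ids):
        n = min(len(seq), max_len)
        out[i, :n] = torch.tensor(seq[:n], dtype=torch.long)
    return out

# Every batch now has shape (batch_size, max_len), regardless of its contents.
batch = pad_to_fixed([[5, 6, 7], [8, 9]], max_len=4, pad_id=1)
```

With HuggingFace tokenizers the same effect comes from `padding="max_length"` with a fixed `max_length`, instead of dynamic per-batch padding.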

gurveervirk commented 1 year ago

Will try and let you know.

gurveervirk commented 1 year ago

I have a couple of .hlo.0 files as suggested. How do I share them with you?

gurveervirk commented 1 year ago

The saved HLO files for the run

gurveervirk commented 1 year ago

Is there anything else I should try?

JackCaoG commented 12 months ago

Sorry let me try to take a look today.

JackCaoG commented 12 months ago

The HLO.0 files only contain a few HLOs, so it is hard for me to figure out why it recompiles. However, I did notice one thing:

[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
  embed (/kaggle/working/florabert/module/florabert/models.py:72)
  forward (/kaggle/working/florabert/module/florabert/models.py:67)
  _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
  _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
  forward (/kaggle/working/florabert/module/florabert/models.py:250)
  _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
  _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
  _mp_fn (/kaggle/working/florabert/scripts/1-modeling/finetune.py:185)
  __call__ (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:178)
  _thread_fn (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:68)
  run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58)
  _worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83)
  run (/usr/local/lib/python3.10/threading.py:953)
  _bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
  _bootstrap (/usr/local/lib/python3.10/threading.py:973)

This is the trigger of one execution, and it is not triggered by mark_step (where most "normal" executions should happen):

  embed (/kaggle/working/florabert/module/florabert/models.py:72)

I took a look and it seems to be

attention_mask[input_ids == self.start_token_idx] = 0

I think input_ids == self.start_token_idx will force an execution, which makes things really slow. It will also potentially introduce dynamic shapes. @Liyang90 do you have any recommendation on how to convert this kind of indexing op into other PyTorch ops like index_select?

JackCaoG commented 12 months ago

If you can use the nightly build following https://github.com/pytorch/xla#python-packages, you can try out our new debugging tool in https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#compilation--execution-analysis to understand why compilation/execution happens.
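For reference, enabling that analysis on the Kaggle run would look something like this (PT_XLA_DEBUG is the environment variable described in the troubleshooting guide linked above; the script path is the one from this thread):

```shell
PT_XLA_DEBUG=1 python /kaggle/working/florabert/scripts/1-modeling/finetune.py
```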

Liyang90 commented 12 months ago

Quoting the comment above:

I think input_ids == self.start_token_idx will force an execution, which makes things really slow. It will also potentially introduce dynamic shapes. @Liyang90 do you have any recommendation on how to convert this kind of indexing op into other PyTorch ops like index_select?

This boolean indexing would likely lower to a nonzero op that produces a dynamically shaped index array. It can easily be replaced with a torch.where op.
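A small self-contained sketch of the rewrite (the tensor values are made up for illustration): the torch.where version produces the same mask as the boolean-index assignment, but with a fixed output shape.

```python
import torch

input_ids = torch.tensor([[0, 5, 5, 2]])   # 0 plays the role of the start token here
attention_mask = torch.ones_like(input_ids)
start_token_idx = 0

# Boolean-index version: lowers to a dynamically shaped index on XLA.
masked_ref = attention_mask.clone()
masked_ref[input_ids == start_token_idx] = 0

# torch.where version: same result, static shapes throughout.
masked = torch.where(input_ids == start_token_idx,
                     torch.zeros_like(attention_mask),
                     attention_mask)

assert torch.equal(masked, masked_ref)
```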

gurveervirk commented 11 months ago

I'll try your suggestions and get back to you.

gurveervirk commented 11 months ago

I have tried replacing both attention_mask[input_ids == self.start_token_idx] = 0 and attention_mask[input_ids == self.end_token_idx] = 0 with attention_mask = torch.where(input_ids == self.start_token_idx, 0, attention_mask) and attention_mask = torch.where(input_ids == self.end_token_idx, 0, attention_mask) respectively. My testing with this for ~50 minutes did not show any improvement. It didn't move to the next step at all.

After checking the HLO files once again, I am curious about how many times the optimizer_step comes up as a graph step. Is this normal? I also substituted xm.optimizer_step() with optimizer.step() but it was to no avail. Anything else I should try?

Update:

I tried it again for longer and after ~70 minutes, the speed picked up. Not sure if the substitution helped, but it ended up working regardless. It said around 90 minutes to completion. I believe it was loading all the data into memory for those 70 minutes, thereby not showing any progress. Is there any other faster or better way to load the data? Or will I have to pay this TPU tax? Also, at the end of the loading process, ~190 GB of RAM was used. Do you suggest increasing the batch size for faster processing?
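On the RAM question: one option (a hedged sketch, not something suggested in this thread) is to stream the TSV lazily with an IterableDataset instead of materializing all of the examples up front. The sequence-tab-labels layout assumed below is the one described at the top of the issue; the class and file names are made up for illustration.

```python
import json
import torch
from torch.utils.data import IterableDataset, DataLoader

class PromoterStream(IterableDataset):
    """Stream 'sequence\\tlabels' lines one at a time from disk
    instead of loading the whole split into RAM."""

    def __init__(self, path):
        self.path = path

    def __iter__(self):
        with open(self.path) as f:
            for line in f:
                seq, labels = line.rstrip("\n").split("\t")
                # labels is a JSON-style list like [54.67, 60.64, ...]
                yield seq, torch.tensor(json.loads(labels))

# loader = DataLoader(PromoterStream("train.tsv"), batch_size=32)
```

This keeps memory flat regardless of dataset size; on TPU the resulting loader would still be wrapped the usual torch_xla way for device feeding.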