gurveervirk opened this issue 1 year ago
Hey, thanks for raising the issue. Can you follow https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#perform-a-auto-metrics-analysis, run it with PT_XLA_DEBUG, and maybe print the metrics per step? If it is super slow, it is usually because it is recompiling, and we should figure out why. RAM usage usually should not keep increasing.
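A minimal sketch of what "print metrics per step" could look like, assuming the existing training loop in finetune.py (`train_loader` and `run_step` stand in for the user's own code) and with PT_XLA_DEBUG=1 set in the environment when launching:

```python
import torch_xla.debug.metrics as met

# In the training loop (train_loader / run_step are placeholders for the user's code):
for step, batch in enumerate(train_loader):
    loss = run_step(batch)              # the existing forward/backward/optimizer step
    print(f"step {step} loss {loss}")
    print(met.metrics_report())         # CompileTime / ExecuteTime / counters accumulated so far
```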
Thanks for the reply. Will do it and let you know.
I am currently running the notebook with the flag set and am also printing the metrics at each step.
Could you also have a look at the code (finetune.py)?
The metrics and profile messages for the first step are as follows:
0%| | 0/310 [00:00<?, ?it/s]pt-xla-profiler: TransferFromServerTime too frequent: 3 counts during 2 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
step : 0 loss : tensor(7.9881, device='xla:0', grad_fn=
The run has been stuck on this (first) step for the past 3 hours. Versions: torch_version = 2.1.0+cu121, torch_xla_version = 2.1.0+libtpu.
This is the full relevant output including profile data and metrics:
0%| | 0/310 [00:00<?, ?it/s]pt-xla-profiler: TransferFromServerTime too frequent: 3 counts during 2 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
step : 0 loss : tensor(7.9881, device='xla:0', grad_fn=
pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 3 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:00<1004:17:26, 11700.47s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 4 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 3 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:01<1004:23:27, 11701.64s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 4 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 5 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:01<1004:24:36, 11701.87s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 6 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 3 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
next iter...
0%| | 1/310 [3:15:02<1004:29:37, 11702.84s/it]pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 4 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 3 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:03<1004:32:35, 11703.42s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 4 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 5 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:03<1004:33:17, 11703.55s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 6 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 5 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:03<1004:34:58, 11703.88s/it]pt-xla-profiler: TransferFromServerTime too frequent: 635 counts during 6 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 5 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
0%| | 1/310 [3:15:06<1004:49:14, 11706.65s/it]pt-xla-profiler: TransferFromServerTime too frequent: 634 counts during 6 steps
pt-xla-profiler: Op(s) not lowered: aten::nonzero, Please open a GitHub issue with the above op lowering requests.
step : 1 loss : tensor(8.1711, device='xla:0', grad_fn=
SIGTERM received by PID 3970 (TID 3970) on cpu 2 from PID 127; stack trace:
SIGTERM received by PID 3968 (TID 3968) on cpu 3 from PID 127; stack trace:
SIGTERM received by PID 3969 (TID 3969) on cpu 58 from PID 127; stack trace:
(the three stack traces were interleaved and unsymbolized: PC @ 0x7ccdf9354da6 (unknown), @ 0x7cca2b1aa53a, @ 0x7ccdf930afd0; symbolizer URLs and address maps omitted)
E1205 09:22:09.119284 3969 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM.
E1205 09:22:09.119295 3968 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM.
E1205 09:22:09.119287 3970 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM.
E1205 09:22:09.615902 3970 process_state.cc:783] RAW: Raising signal 15 with default behavior
E1205 09:22:09.666330 3968 process_state.cc:783] RAW: Raising signal 15 with default behavior
E1205 09:22:09.744279 3969 process_state.cc:783] RAW: Raising signal 15 with default behavior
Traceback (most recent call last):
File "/kaggle/working/florabert/scripts/1-modeling/finetune.py", line 220, in
There are a couple of metrics dumps here; the first one looks OK:
Metric: CompileTime
TotalSamples: 5
Accumulator: 06s572ms100.758us
ValueRate: 950ms799.286us / second
Rate: 0.852281 / second
Percentiles: 1%=024ms472.870us; 5%=024ms472.870us; 10%=024ms472.870us; 20%=025ms067.999us; 50%=087ms213.060us; 80%=05s050ms996.673us; 90%=05s050ms996.673us; 95%=05s050ms996.673us; 99%=05s050ms996.673us
Metric: ExecuteTime
TotalSamples: 8
Accumulator: 145ms336.276us
ValueRate: 024ms352.694us / second
Rate: 1.34049 / second
Percentiles: 1%=001ms136.278us; 5%=001ms136.278us; 10%=001ms136.278us; 20%=001ms301.364us; 50%=002ms970.877us; 80%=031ms829.933us; 90%=104ms480.334us; 95%=104ms480.334us; 99%=104ms480.334us
where you compile 5 times and execute 8 times. It is normal to compile a couple of times at the beginning of training. However, if I look at the last one:
Metric: CompileTime
TotalSamples: 331
Accumulator: 05h42m14s504ms505.931us
ValueRate: 01s447ms883.923us / second
Rate: 0.0282823 / second
Percentiles: 1%=025ms067.999us; 5%=05s050ms996.673us; 10%=46s890ms058.574us; 20%=48s877ms971.919us; 50%=55s575ms514.354us; 80%=58s507ms508.401us; 90%=60s584ms563.204us; 95%=01m01s644ms909.657us; 99%=01m02s067ms278.600us
Metric: ExecuteTime
TotalSamples: 644
Accumulator: 23m01s131ms410.269us
ValueRate: 118ms980.164us / second
Rate: 0.0550123 / second
Percentiles: 1%=896.992us; 5%=992.987us; 10%=001ms102.272us; 20%=001ms288.992us; 50%=719ms605.355us; 80%=05s564ms161.242us; 90%=06s098ms615.621us; 95%=07s158ms827.174us; 99%=09s799ms785.496us
it seems like the majority of the time is spent on recompiling. https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/ might help you get some insight into how lazy tensors work. Essentially, on every step you are executing a program that's slightly different (maybe the ops are the same but the input shapes keep changing), so the compiler keeps recompiling. You want to dump the IR or HLO for every step and compare them. They should be identical.
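One way to do this is to set XLA_SAVE_TENSORS_FILE (and optionally XLA_SAVE_TENSORS_FMT=hlo) in the environment so every synced graph is appended to a file. A rough per-step alternative using torch_xla's internal IR dump helper is sketched below (the output directory and calling convention are placeholders, not part of the thread's finetune.py):

```python
import os
import torch_xla

def dump_step_ir(loss, step, out_dir="/tmp/ir_dumps"):  # placeholder path
    # Write the lazy-tensor IR of the graph feeding `loss`, so consecutive
    # steps can be diffed to spot shape/op changes that force recompilation.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"step_{step}.txt"), "w") as f:
        f.write(torch_xla._XLAC._get_xla_tensors_text([loss]))

# In the training loop, call dump_step_ir(loss, step) right before xm.mark_step(),
# then compare dumps, e.g. `diff step_3.txt step_4.txt`.
```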
Will try and let you know.
I have a couple of .hlo.0 files as suggested. How do I share them with you?
Is there anything else I should try?
Sorry let me try to take a look today.
The .hlo.0 files only contain a few HLO graphs, so it would be hard for me to figure out why it recompiles. However, I did notice one thing:
[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
embed (/kaggle/working/florabert/module/florabert/models.py:72)
forward (/kaggle/working/florabert/module/florabert/models.py:67)
_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
_wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
forward (/kaggle/working/florabert/module/florabert/models.py:250)
_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)
_wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)
_mp_fn (/kaggle/working/florabert/scripts/1-modeling/finetune.py:185)
__call__ (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:178)
_thread_fn (/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:68)
run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58)
_worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83)
run (/usr/local/lib/python3.10/threading.py:953)
_bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
_bootstrap (/usr/local/lib/python3.10/threading.py:973)
This is the trigger of one of the executions, and it is not triggered by mark_step (which is where most "normal" executions should come from):
embed (/kaggle/working/florabert/module/florabert/models.py:72)
I took a look and it seems to be
attention_mask[input_ids == self.start_token_idx] = 0
I think input_ids == self.start_token_idx will force an execution, which makes things really slow. It will also potentially introduce a dynamic shape. @Liyang90 do you have any recommendation on how to convert this kind of indexing op into other PyTorch ops like index_select?
If you can use the nightly build following https://github.com/pytorch/xla#python-packages, you can try out our new debugging tool in https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#compilation--execution-analysis to understand why compilation/execution happens.
attention_mask[input_ids == self.start_token_idx] = 0
This would produce a boolean indexing and likely a nonzero op that produces a dynamically shaped index array. It can easily be replaced with a torch.where op.
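A minimal sketch of that replacement, assuming `input_ids`, `attention_mask`, and the token index come from the user's models.py (the helper name is made up for illustration):

```python
import torch

def zero_mask_at_token(attention_mask: torch.Tensor,
                       input_ids: torch.Tensor,
                       token_idx: int) -> torch.Tensor:
    # Replaces `attention_mask[input_ids == token_idx] = 0`, which lowers through
    # aten::nonzero and yields a dynamically shaped index tensor, with a
    # torch.where whose output shape is static.
    return torch.where(input_ids == token_idx,
                       torch.zeros_like(attention_mask),
                       attention_mask)

# e.g. attention_mask = zero_mask_at_token(attention_mask, input_ids, self.start_token_idx)
```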
I'll try your suggestions and get back to you.
I have tried replacing attention_mask[input_ids == self.start_token_idx] = 0 and attention_mask[input_ids == self.end_token_idx] = 0 with attention_mask = torch.where(input_ids == self.start_token_idx, 0, attention_mask) and attention_mask = torch.where(input_ids == self.end_token_idx, 0, attention_mask), respectively. Testing this for ~50 minutes did not show any improvement; it didn't move to the next step at all.
After checking the HLO files once again, I am curious about how many times optimizer_step comes up as a graph step. Is this normal? I also substituted xm.optimizer_step() with optimizer.step(), but to no avail. Anything else I should try?
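For reference, a minimal sketch of the ordering xm.optimizer_step and mark_step usually take in a torch_xla training step (illustrative only, not the thread's finetune.py; `model`, `batch`, and `optimizer` are placeholders):

```python
import torch_xla.core.xla_model as xm

def xla_train_step(model, batch, optimizer):
    """One training step in the usual torch_xla ordering (sketch)."""
    optimizer.zero_grad()
    loss = model(**batch).loss      # assumes an HF-style model that returns .loss
    loss.backward()
    xm.optimizer_step(optimizer)    # all-reduces gradients across replicas, then optimizer.step()
    xm.mark_step()                  # cuts and executes the pending graph
                                    # (MpDeviceLoader/ParallelLoader does this implicitly per batch)
    return loss
```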
Update:
I tried it again for longer, and after ~70 minutes the speed picked up. I'm not sure if the substitution helped, but it ended up working regardless; it estimated around 90 minutes to completion. I believe it was loading all the data into memory for those 70 minutes, which is why no progress was shown. Is there any faster or better way to load the data, or will I have to pay this TPU tax? Also, at the end of the loading process, ~190 GB of RAM was in use. Do you suggest increasing the batch size for faster processing?
❓ Questions and Help
I have pretrained roberta-base on DNA promoter sequences of plants (working on a project). I am currently trying to fine-tune it on a downstream task of predicting gene expression values, basically a list of 8 values (corresponding to various tissues) from a single promoter sequence.
This wasn't possible on Kaggle's GPU (due to memory restrictions), so I tried to do the same on a TPU using pytorch-xla (I figured that was the best option). The link to the notebook as well as the datasets used are as follows:
Version 43 is the one using the pytorch-xla code (as far as I could figure out). The data's format is as follows:
sequence \t labels
dna_promoter_seq_here \t list_of_8_values_here
eg: CTCAAGCTGAGCAGTGGGTTTGCTCTGGAGGGGAAGCTCAACGGTGGCGACAAGGAAGAATCTGCTTGCGAGGCGAGCCCTGACGCCGCTGATAGCGACCAAAGGTGGATTAAACAACCCATTTCATCATTCTTCTTCCTTGTTAGTTATGATTCCCACGCTTGCCTTTCATGAATCATGATCCTATATGTATATTGATATTAATCAGTTCTAGAAAGTTCAACAACATTTGAGCATGTCAAAACCTGATCGTTGCCTGTTCCATGTCAACAGTGGATTATAACACGTGCAAATGTAGCTATTTGTGTGAGAAGACGTGTGATCGACTCTTTTTTTATATAGATAGCATTGAGATCAACTGTTTGTATATATCTTGTCATAACATTTTTACTTCGTAGCAACGTACGAGCGTTCACCTATTTGTATATAAGTTATCATGATATTTATAAGTTACCGTTGCAACGCACGGACACTCACCTAGTATAGTTTATGTATTACAGTACTAGGAGCCCTAGGCTTCCAATAACTAGAAAAAGTCCTGGTCAGTCGAACCAAACCACAATCCGACGTATACATTCTGGTTCCCCCACGCCCCCATCCGTTCGATTCA [54.679647, 60.646678, 54.9113, 78.878474, 21.326259, 27.973276, 17.419968, 40.465529]
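For clarity, a small sketch of how one such record could be parsed (the helper name, and the assumption that the labels are a bracketed Python-style list as in the example above, are mine):

```python
import ast

def parse_record(line: str):
    """Parse one tab-separated record: promoter sequence, then a list of 8 expression values."""
    sequence, labels = line.rstrip("\n").split("\t")
    values = ast.literal_eval(labels)   # "[54.679647, ..., 40.465529]" -> list of floats
    assert len(values) == 8
    return sequence, values
```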
There are ~722,000 examples of this kind, ~722 MB in total, divided into ~400 MB train, ~200 MB test, and ~100 MB eval. When running finetune.py, all goes well until the training starts (datasets are loaded, processed, etc.). But the latest run took 3+ hours to get to the next step, and the RAM usage kept increasing. It looked like the TPU run was very slow, and the run then crashed when it ran out of memory. I have tried accelerate and Trainer, but those efforts were in vain.
A few questions:
If I pass the model as an arg to xmp.spawn, I end up seeing either "Check failed: data()->tensor_data" or "RuntimeError: Function AddcmulBackward0 returned an invalid gradient at index 1 - expected device xla:1 but got xla:0". Why?
Kindly guide.