Hi! Can you try this version (I added a couple of debug options):
https://gist.github.com/dlibenzi/f021e40bb0524ce368b58aaf478d689c
Also, export these in the environment:
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export TF_CPP_VMODULE=tensor=5
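If it's easier, the same flags can also be set from Python, as long as it happens before torch_xla is imported; a minimal sketch:

import os

# Set the debug flags before torch_xla (and the TF runtime it loads) is imported.
os.environ['XLA_IR_DEBUG'] = '1'
os.environ['XLA_HLO_DEBUG'] = '1'
os.environ['TF_CPP_VMODULE'] = 'tensor=5'

import torch_xla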
Given that you gave us a full repro, it would help if you could also include a minimal traindata.pt file we could try on our side (or provide the required input shape so we can use our xu.SampleGenerator).
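For reference, feeding synthetic data through xu.SampleGenerator looks roughly like this (a minimal sketch; the batch size and sequence length below are placeholders, not the real shapes from your traindata.pt):

import torch
import torch_xla.utils.utils as xu

# Placeholder shapes; set these to match the real training data.
batch_size, seq_len = 32, 999
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, seq_len), torch.zeros(batch_size, seq_len)),
    sample_count=100)

for data, target in train_loader:
    pass  # same training step as with the real data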
Hello and thank you for your answer. I tried your version with the exports you recommended; here is the log: log.txt. traindata.pt was generated with the generate_sine_wave.py script from the repo I linked in the message above.
With this, I was able to get out the metrics report:
https://gist.github.com/dlibenzi/a18ab4b6e675208bf8277a53a0f42be5
It shows one op that we still need to lower to XLA (it currently falls back to PyTorch CPU):
Counter: aten::mse_loss
  Value: 1
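For reference, the report above can also be dumped from within the training script; a minimal sketch:

import torch_xla.debug.metrics as met

# aten::* counters in the report mark ops that fell back to the PyTorch CPU
# path instead of being lowered to XLA.
print(met.metrics_report())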
The model uses torch int64 (long), which is not necessary. Either change it to int32 (torch int), or:
export XLA_USE_32BIT_LONG=1
While we implement that (we will do it soon), maybe you can try a different loss function in the meantime. If you are on nightly (Docker and TPU VM) we have L1Loss.
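Something along these lines (a rough sketch; output, target and the long-typed tensor are placeholder names for whatever your script uses):

import torch
import torch.nn as nn

# Cast any int64 (long) tensors down to int32
# (only needed if you don't export XLA_USE_32BIT_LONG=1).
some_long_tensor = some_long_tensor.to(torch.int32)

# Until aten::mse_loss is lowered, L1Loss avoids the CPU fallback for the loss.
criterion = nn.L1Loss()
loss = criterion(output, target)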
I was able to get the source attached below running on a Cloud TPU.
TRIM_GRAPH_CHECK_FREQUENCY=100000 TRIM_GRAPH_SIZE=1000000 XLA_USE_32BIT_LONG=1 python ~/lstm.py
The first time it takes a while to compile, but once the compilation is cached (unless the shapes or the graph change), the next runs should go pretty quickly. The ExecuteTime looks pretty small, almost too small 😄
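If you want to sanity-check the step time yourself, keep in mind that execution is lazy, so force the result to materialize before stopping the clock; a rough sketch (train_step is a placeholder for one training iteration):

import time

start = time.time()
loss = train_step()
# .item() pulls the value to the CPU, which forces the pending XLA execution to finish.
loss_value = loss.item()
print('step time:', time.time() - start)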
Source:
https://gist.github.com/dlibenzi/0be1687922d99f55188f585ac7ac1534
I have disabled the future prediction in this run. 1000 LSTM cells would definitely be too much.
Metrics:
https://gist.github.com/dlibenzi/54fd4433260ff62c715cb8cb9f5d7747
I ran your modified code and after a while it finally worked! I intend to run more complex models (encoder/decoder with attention) using these methods. Thank you very much for your support :smile:
Note that with pytorch/xla there can be a long compilation time at the beginning for some models (like ones with explicit long loops). Unless the graph (or the tensor shapes) changes, the compilation will be cached for the next runs (unless you explicitly restart the TPU node).
I see. Would it be faster if the model were serialized with JIT?
I honestly think native PyTorch on GPU might be faster, but we are at a very early stage and have not done any performance comparison yet.
I understand. Thanks again for all the help :smiley:
Hello and thank you for providing a way to run pytorch on TPUs! I'm trying to test a simple example taken from https://github.com/pytorch/examples/tree/master/time_sequence_prediction. I tried to follow the API Guide to make the necessary changes so it can run on a single TPU. These are the changes I made to the train.py script (roughly sketched below):
The program gets stuck right after the loss is calculated. Is there something I might have missed? Thank you for your time!
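A simplified sketch of those changes (the Sequence model and the LBFGS optimizer come from the original example; train_input and train_target are placeholders for the tensors loaded from traindata.pt):

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

device = xm.xla_device()            # the single TPU device
model = Sequence().to(device)       # Sequence is the LSTM model from the example
criterion = nn.MSELoss()
optimizer = optim.LBFGS(model.parameters(), lr=0.8)

def closure():
    optimizer.zero_grad()
    out = model(train_input.to(device))
    loss = criterion(out, train_target.to(device))
    print('loss:', loss.item())     # roughly where the run appears to get stuck
    loss.backward()
    return loss

optimizer.step(closure)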