Closed by jsadler2 4 years ago
When I run Xiaowei's model, I get running times of about 0.05 seconds per training batch.
When I run my TF2 version, I get running times of about 4.5 seconds per batch, roughly 100x slower.
Both timings are from my local machine.
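For reference, per-batch timings like these can be collected with a small wall-clock helper. This is a minimal sketch, not the code used for the numbers above: `mean_batch_seconds`, `train_step`, and the `warmup` parameter are hypothetical names, and the stand-in step just sleeps.

```python
import time


def mean_batch_seconds(train_step, batches, warmup=1):
    """Average wall-clock seconds per train_step call, skipping warm-up batches."""
    durations = []
    for i, batch in enumerate(batches):
        start = time.perf_counter()
        train_step(batch)
        elapsed = time.perf_counter() - start
        # The first calls often include one-time setup, so they are excluded.
        if i >= warmup:
            durations.append(elapsed)
    if not durations:
        raise ValueError("need more batches than warm-up steps")
    return sum(durations) / len(durations)


if __name__ == "__main__":
    # Stand-in for a real training step (hypothetical): sleeps for 10 ms.
    fake_step = lambda batch: time.sleep(0.01)
    print(f"{mean_batch_seconds(fake_step, range(5)):.3f} s/batch")
```

Skipping the first batch matters when the first call does extra work, so the two models should be timed with the same warm-up setting.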
There is a whole SO thread on TF2 vs TF1: https://stackoverflow.com/questions/58441514/why-is-tensorflow-2-much-slower-than-tensorflow-1
Per that thread, I tried turning off eager mode with tf.compat.v1.disable_eager_execution(). When I did this, it never actually ran. It just sat there for about 20 minutes before I killed it.
Then I tried wrapping my training loop in a function (train_model) and adding the @tf.function decorator to it. That also just sat there indefinitely.
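One pattern that often avoids long stalls is to decorate only a single training step with @tf.function and keep the epoch and batch loops in plain Python. A Python for loop over a list or range inside a tf.function is unrolled while the graph is traced, which can make tracing very slow. Whether that caused the hang described above is not confirmed. The sketch below uses a toy Dense model as a stand-in for the RGCN; `model`, `optimizer`, `loss_fn`, and `train_step` are hypothetical names.

```python
import tensorflow as tf

# Toy stand-in model (assumption): 3 input features, 1 output.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()


@tf.function
def train_step(x, y):
    """One gradient update, compiled to a graph on first call."""
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


def train_model(dataset, epochs):
    """Plain-Python loops; only train_step runs as a graph."""
    loss = None
    for _ in range(epochs):
        for x, y in dataset:
            loss = train_step(x, y)
    return None if loss is None else float(loss.numpy())
```

Here train_model itself is left undecorated, so the graph is traced once for train_step and then reused for every batch with the same input shapes.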
After some changes (https://github.com/jsadler2/drb-dl-model/pull/9), I got it running a lot faster and it's learning faster too. Here are updated results.
Mine (mse):
Train on 1008 samples
Epoch 1/3
42/1008 [>.............................] - ETA: 40:50 - loss: 0.6907
84/1008 [=>............................] - ETA: 19:36 - loss: 0.8269
126/1008 [==>...........................] - ETA: 12:31 - loss: 0.8656
168/1008 [====>.........................] - ETA: 8:58 - loss: 0.8490
210/1008 [=====>........................] - ETA: 6:51 - loss: 0.8381
252/1008 [======>.......................] - ETA: 5:25 - loss: 0.8239
294/1008 [=======>......................] - ETA: 4:24 - loss: 0.8210
336/1008 [=========>....................] - ETA: 3:38 - loss: 0.8665
378/1008 [==========>...................] - ETA: 3:03 - loss: 0.8636
420/1008 [===========>..................] - ETA: 2:34 - loss: 0.8865
462/1008 [============>.................] - ETA: 2:10 - loss: 0.9216
504/1008 [==============>...............] - ETA: 1:51 - loss: 0.9118
546/1008 [===============>..............] - ETA: 1:34 - loss: 0.8994
588/1008 [================>.............] - ETA: 1:20 - loss: 0.9007
630/1008 [=================>............] - ETA: 1:07 - loss: 0.9033
672/1008 [===================>..........] - ETA: 56s - loss: 0.9013
714/1008 [====================>.........] - ETA: 46s - loss: 0.9148
756/1008 [=====================>........] - ETA: 38s - loss: 0.9387
798/1008 [======================>.......] - ETA: 30s - loss: 0.9304
840/1008 [========================>.....] - ETA: 23s - loss: 0.9655
882/1008 [=========================>....] - ETA: 16s - loss: 0.9611
924/1008 [==========================>...] - ETA: 10s - loss: 0.9652
966/1008 [===========================>..] - ETA: 5s - loss: 0.9554
1008/1008 [==============================] - 118s 117ms/sample - loss: 0.9346
Epoch 2/3
42/1008 [>.............................] - ETA: 19s - loss: 0.6072
84/1008 [=>............................] - ETA: 17s - loss: 0.5916
126/1008 [==>...........................] - ETA: 15s - loss: 0.6118
168/1008 [====>.........................] - ETA: 14s - loss: 0.6045
210/1008 [=====>........................] - ETA: 14s - loss: 0.5670
252/1008 [======>.......................] - ETA: 13s - loss: 0.5990
294/1008 [=======>......................] - ETA: 12s - loss: 0.5828
336/1008 [=========>....................] - ETA: 12s - loss: 0.6086
378/1008 [==========>...................] - ETA: 11s - loss: 0.6659
420/1008 [===========>..................] - ETA: 10s - loss: 0.6357
462/1008 [============>.................] - ETA: 10s - loss: 0.6418
504/1008 [==============>...............] - ETA: 9s - loss: 0.6181
546/1008 [===============>..............] - ETA: 8s - loss: 0.6188
588/1008 [================>.............] - ETA: 8s - loss: 0.6130
630/1008 [=================>............] - ETA: 7s - loss: 0.6011
672/1008 [===================>..........] - ETA: 6s - loss: 0.5829
714/1008 [====================>.........] - ETA: 5s - loss: 0.5914
756/1008 [=====================>........] - ETA: 4s - loss: 0.5916
798/1008 [======================>.......] - ETA: 4s - loss: 0.5859
840/1008 [========================>.....] - ETA: 3s - loss: 0.5868
882/1008 [=========================>....] - ETA: 2s - loss: 0.5948
924/1008 [==========================>...] - ETA: 1s - loss: 0.5975
966/1008 [===========================>..] - ETA: 0s - loss: 0.5981
1008/1008 [==============================] - 21s 20ms/sample - loss: 0.5942
Epoch 3/3
42/1008 [>.............................] - ETA: 23s - loss: 0.5830
84/1008 [=>............................] - ETA: 22s - loss: 0.5640
126/1008 [==>...........................] - ETA: 21s - loss: 0.5116
168/1008 [====>.........................] - ETA: 20s - loss: 0.5034
210/1008 [=====>........................] - ETA: 19s - loss: 0.4893
252/1008 [======>.......................] - ETA: 18s - loss: 0.4621
294/1008 [=======>......................] - ETA: 17s - loss: 0.4932
336/1008 [=========>....................] - ETA: 16s - loss: 0.5396
378/1008 [==========>...................] - ETA: 16s - loss: 0.5386
420/1008 [===========>..................] - ETA: 15s - loss: 0.6023
462/1008 [============>.................] - ETA: 14s - loss: 0.6054
504/1008 [==============>...............] - ETA: 13s - loss: 0.6003
546/1008 [===============>..............] - ETA: 12s - loss: 0.5738
588/1008 [================>.............] - ETA: 11s - loss: 0.5772
630/1008 [=================>............] - ETA: 10s - loss: 0.5618
672/1008 [===================>..........] - ETA: 9s - loss: 0.5452
714/1008 [====================>.........] - ETA: 8s - loss: 0.5522
756/1008 [=====================>........] - ETA: 7s - loss: 0.5353
798/1008 [======================>.......] - ETA: 6s - loss: 0.5301
840/1008 [========================>.....] - ETA: 4s - loss: 0.5465
882/1008 [=========================>....] - ETA: 3s - loss: 0.5491
924/1008 [==========================>...] - ETA: 2s - loss: 0.5421
966/1008 [===========================>..] - ETA: 1s - loss: 0.5326
1008/1008 [==============================] - 31s 30ms/sample - loss: 0.5300
elapsed time: 0:03:51.420754
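To compare this log with the per-batch timings quoted earlier, the epoch durations can be converted to seconds per batch. This is a rough sketch: the batch size of 42 is an assumption inferred from the progress-bar increments (42, 84, 126, ...), and `seconds_per_batch` is a hypothetical helper.

```python
SAMPLES_PER_EPOCH = 1008
BATCH_SIZE = 42  # assumption: inferred from the progress-bar step size
EPOCH_SECONDS = [118, 21, 31]  # epoch durations reported in the log above


def seconds_per_batch(epoch_seconds, samples=SAMPLES_PER_EPOCH, batch_size=BATCH_SIZE):
    """Convert whole-epoch durations into average seconds per batch."""
    batches = samples // batch_size
    return [s / batches for s in epoch_seconds]


for epoch, spb in enumerate(seconds_per_batch(EPOCH_SECONDS), start=1):
    print(f"epoch {epoch}: {spb:.3f} s/batch")
```

Under that assumption, epochs 2 and 3 come out at roughly 0.9 to 1.3 seconds per batch. The first epoch is much slower, which may reflect one-time setup cost, though the log alone does not confirm that.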
Xiaowei's (rmse):
Pretraining starts
==================================
epoch 0: loss 0.0000: loss_s 14.5018: loss_p 0.8718
epoch 0: loss 0.0000: loss_s 14.5994: loss_p 0.7974
epoch 0: loss 0.0000: loss_s 14.8881: loss_p 0.7809
epoch 1: loss 0.0000: loss_s 14.4836: loss_p 0.5517
epoch 1: loss 0.0000: loss_s 14.6092: loss_p 0.7007
epoch 1: loss 0.0000: loss_s 14.8961: loss_p 0.7405
epoch 2: loss 0.0000: loss_s 14.4834: loss_p 0.5302
epoch 2: loss 0.0000: loss_s 14.6051: loss_p 0.6104
epoch 2: loss 0.0000: loss_s 14.8893: loss_p 0.5517
elapsed time: 0:02:47.654620
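Since the two logs are labeled with different metrics (mse for mine, rmse for Xiaowei's), the loss values are not directly comparable. A rough conversion is to take the square root of the end-of-epoch MSE values. This is only approximate, because the Keras progress bar shows a running mean of batch losses, and the square root of a mean is not the mean of square roots. `mse_to_rmse` is a hypothetical helper.

```python
import math

# End-of-epoch MSE values copied from the Keras log above.
EPOCH_MSE = [0.9346, 0.5942, 0.5300]


def mse_to_rmse(values):
    """Square root of each MSE value, for rough comparison with RMSE logs."""
    return [math.sqrt(v) for v in values]


for epoch, rmse in enumerate(mse_to_rmse(EPOCH_MSE), start=1):
    print(f"epoch {epoch}: approx rmse {rmse:.4f}")
```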
I got the TF2 version of the RGCN model to run; that is, it is training. There are a couple of problems, though.