USGS-R / river-dl

Deep learning model for predicting environmental variables on river systems
Creative Commons Zero v1.0 Universal

tf2 version is way slower, less accurate #8

Closed: jsadler2 closed this issue 4 years ago

jsadler2 commented 4 years ago

I got the tf2 version of the rgcn model to run; that is, it is training. There are a couple of problems, though.

jsadler2 commented 4 years ago
  1. It's a lot slower.

When I run Xiaowei's model, I get running times of about 0.05 seconds per training batch.

When I run my tf2 version, I get running times of about 4.5 seconds per batch! That's roughly 90x slower.
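To compare per-batch running times like these on the same machine, one option is to time each training step with a wall clock, skip the first call (which often carries one-time setup cost), and compare medians. A minimal sketch, assuming a hypothetical `train_step` callable and an iterable of batches; neither name comes from either codebase:

```python
import statistics
import time
from typing import Callable, Iterable, List


def time_batches(train_step: Callable[[object], object],
                 batches: Iterable[object],
                 warmup: int = 1) -> List[float]:
    """Return wall-clock seconds for each call to train_step.

    The first `warmup` batches are run but not recorded, because the first
    call can include one-time costs such as graph tracing.
    """
    times = []
    for i, batch in enumerate(batches):
        start = time.perf_counter()
        train_step(batch)  # hypothetical: one optimizer update on one batch
        elapsed = time.perf_counter() - start
        if i >= warmup:
            times.append(elapsed)
    return times


def slowdown(new_times: List[float], baseline_times: List[float]) -> float:
    """Ratio of median batch times, new over baseline."""
    return statistics.median(new_times) / statistics.median(baseline_times)


if __name__ == "__main__":
    # Dummy step that just sleeps, standing in for a real training step.
    def fake_step(batch):
        time.sleep(0.01)

    per_batch = time_batches(fake_step, range(5))
    print(f"median seconds per batch: {statistics.median(per_batch):.3f}")
    # Using the figures quoted above: 4.5 s vs 0.05 s per batch.
    print(slowdown([4.5], [0.05]))  # about 90
```

Using the median rather than the mean keeps a single slow first batch or an occasional hiccup from dominating the comparison.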

jsadler2 commented 4 years ago

These timings are from my local machine.

jsadler2 commented 4 years ago

There is a whole Stack Overflow thread on TF2 vs. TF1 performance: https://stackoverflow.com/questions/58441514/why-is-tensorflow-2-much-slower-than-tensorflow-1

jsadler2 commented 4 years ago

Following that thread, I tried turning off eager execution with tf.compat.v1.disable_eager_execution(). When I did this, training never actually started; it just sat there for about 20 minutes before I killed it.

jsadler2 commented 4 years ago

Then I tried wrapping my training loop in a function (train_model) and adding the @tf.function decorator to it. That version also just sat there forever.

jsadler2 commented 4 years ago

After some changes (https://github.com/jsadler2/drb-dl-model/pull/9), I got it running a lot faster, and it's learning faster too. Here are the updated results. Mine (loss is MSE):

Train on 1008 samples
Epoch 1/3

  42/1008 [>.............................] - ETA: 40:50 - loss: 0.6907
  84/1008 [=>............................] - ETA: 19:36 - loss: 0.8269
 126/1008 [==>...........................] - ETA: 12:31 - loss: 0.8656
 168/1008 [====>.........................] - ETA: 8:58 - loss: 0.8490 
 210/1008 [=====>........................] - ETA: 6:51 - loss: 0.8381
 252/1008 [======>.......................] - ETA: 5:25 - loss: 0.8239
 294/1008 [=======>......................] - ETA: 4:24 - loss: 0.8210
 336/1008 [=========>....................] - ETA: 3:38 - loss: 0.8665
 378/1008 [==========>...................] - ETA: 3:03 - loss: 0.8636
 420/1008 [===========>..................] - ETA: 2:34 - loss: 0.8865
 462/1008 [============>.................] - ETA: 2:10 - loss: 0.9216
 504/1008 [==============>...............] - ETA: 1:51 - loss: 0.9118
 546/1008 [===============>..............] - ETA: 1:34 - loss: 0.8994
 588/1008 [================>.............] - ETA: 1:20 - loss: 0.9007
 630/1008 [=================>............] - ETA: 1:07 - loss: 0.9033
 672/1008 [===================>..........] - ETA: 56s - loss: 0.9013 
 714/1008 [====================>.........] - ETA: 46s - loss: 0.9148
 756/1008 [=====================>........] - ETA: 38s - loss: 0.9387
 798/1008 [======================>.......] - ETA: 30s - loss: 0.9304
 840/1008 [========================>.....] - ETA: 23s - loss: 0.9655
 882/1008 [=========================>....] - ETA: 16s - loss: 0.9611
 924/1008 [==========================>...] - ETA: 10s - loss: 0.9652
 966/1008 [===========================>..] - ETA: 5s - loss: 0.9554 
1008/1008 [==============================] - 118s 117ms/sample - loss: 0.9346
Epoch 2/3

  42/1008 [>.............................] - ETA: 19s - loss: 0.6072
  84/1008 [=>............................] - ETA: 17s - loss: 0.5916
 126/1008 [==>...........................] - ETA: 15s - loss: 0.6118
 168/1008 [====>.........................] - ETA: 14s - loss: 0.6045
 210/1008 [=====>........................] - ETA: 14s - loss: 0.5670
 252/1008 [======>.......................] - ETA: 13s - loss: 0.5990
 294/1008 [=======>......................] - ETA: 12s - loss: 0.5828
 336/1008 [=========>....................] - ETA: 12s - loss: 0.6086
 378/1008 [==========>...................] - ETA: 11s - loss: 0.6659
 420/1008 [===========>..................] - ETA: 10s - loss: 0.6357
 462/1008 [============>.................] - ETA: 10s - loss: 0.6418
 504/1008 [==============>...............] - ETA: 9s - loss: 0.6181 
 546/1008 [===============>..............] - ETA: 8s - loss: 0.6188
 588/1008 [================>.............] - ETA: 8s - loss: 0.6130
 630/1008 [=================>............] - ETA: 7s - loss: 0.6011
 672/1008 [===================>..........] - ETA: 6s - loss: 0.5829
 714/1008 [====================>.........] - ETA: 5s - loss: 0.5914
 756/1008 [=====================>........] - ETA: 4s - loss: 0.5916
 798/1008 [======================>.......] - ETA: 4s - loss: 0.5859
 840/1008 [========================>.....] - ETA: 3s - loss: 0.5868
 882/1008 [=========================>....] - ETA: 2s - loss: 0.5948
 924/1008 [==========================>...] - ETA: 1s - loss: 0.5975
 966/1008 [===========================>..] - ETA: 0s - loss: 0.5981
1008/1008 [==============================] - 21s 20ms/sample - loss: 0.5942
Epoch 3/3

  42/1008 [>.............................] - ETA: 23s - loss: 0.5830
  84/1008 [=>............................] - ETA: 22s - loss: 0.5640
 126/1008 [==>...........................] - ETA: 21s - loss: 0.5116
 168/1008 [====>.........................] - ETA: 20s - loss: 0.5034
 210/1008 [=====>........................] - ETA: 19s - loss: 0.4893
 252/1008 [======>.......................] - ETA: 18s - loss: 0.4621
 294/1008 [=======>......................] - ETA: 17s - loss: 0.4932
 336/1008 [=========>....................] - ETA: 16s - loss: 0.5396
 378/1008 [==========>...................] - ETA: 16s - loss: 0.5386
 420/1008 [===========>..................] - ETA: 15s - loss: 0.6023
 462/1008 [============>.................] - ETA: 14s - loss: 0.6054
 504/1008 [==============>...............] - ETA: 13s - loss: 0.6003
 546/1008 [===============>..............] - ETA: 12s - loss: 0.5738
 588/1008 [================>.............] - ETA: 11s - loss: 0.5772
 630/1008 [=================>............] - ETA: 10s - loss: 0.5618
 672/1008 [===================>..........] - ETA: 9s - loss: 0.5452 
 714/1008 [====================>.........] - ETA: 8s - loss: 0.5522
 756/1008 [=====================>........] - ETA: 7s - loss: 0.5353
 798/1008 [======================>.......] - ETA: 6s - loss: 0.5301
 840/1008 [========================>.....] - ETA: 4s - loss: 0.5465
 882/1008 [=========================>....] - ETA: 3s - loss: 0.5491
 924/1008 [==========================>...] - ETA: 2s - loss: 0.5421
 966/1008 [===========================>..] - ETA: 1s - loss: 0.5326
1008/1008 [==============================] - 31s 30ms/sample - loss: 0.5300
elapsed time:  0:03:51.420754
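Keras reports these epochs in ms/sample, while the earlier timings were in seconds per batch. A small conversion sketch, assuming each progress-bar step (42 samples in this log) corresponds to one batch, so the batch size is 42; that is an inference from the log, not a documented setting, and the ms/sample figures are already rounded:

```python
def seconds_per_batch(ms_per_sample: float, batch_size: int) -> float:
    """Convert a Keras-style ms/sample figure to seconds per batch."""
    return ms_per_sample * batch_size / 1000.0


# Per-epoch ms/sample figures from the log above; batch_size=42 is assumed.
for epoch, ms in enumerate([117, 20, 30], start=1):
    print(f"epoch {epoch}: {seconds_per_batch(ms, 42):.2f} s/batch")
# epoch 1: 4.91 s/batch
# epoch 2: 0.84 s/batch
# epoch 3: 1.26 s/batch
```

The first epoch's figure appears to include a long first step (its initial ETA is over 40 minutes), so the later epochs are the fairer basis for comparison.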
jsadler2 commented 4 years ago

Xiaowei's (RMSE):

Pretraining starts
==================================
epoch 0: loss 0.0000: loss_s 14.5018: loss_p 0.8718
epoch 0: loss 0.0000: loss_s 14.5994: loss_p 0.7974
epoch 0: loss 0.0000: loss_s 14.8881: loss_p 0.7809
epoch 1: loss 0.0000: loss_s 14.4836: loss_p 0.5517
epoch 1: loss 0.0000: loss_s 14.6092: loss_p 0.7007
epoch 1: loss 0.0000: loss_s 14.8961: loss_p 0.7405
epoch 2: loss 0.0000: loss_s 14.4834: loss_p 0.5302
epoch 2: loss 0.0000: loss_s 14.6051: loss_p 0.6104
epoch 2: loss 0.0000: loss_s 14.8893: loss_p 0.5517
elapsed time:  0:02:47.654620
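The two logs label their losses differently (MSE for the tf2 run, RMSE for Xiaowei's), so their numbers are not directly comparable. A rough sketch for putting them on the same scale and comparing the two elapsed-time strings, which match the `H:MM:SS.ffffff` format of Python's `str(datetime.timedelta)`. Two caveats: the square root of an epoch-averaged MSE is only an approximate RMSE (the mean of per-batch RMSEs is not the square root of the mean MSE), and the two runs may not do the same amount of work, so the time ratio is a wall-clock comparison only:

```python
import math
from datetime import timedelta


def mse_to_rmse(mse: float) -> float:
    """Square root of a mean squared error, giving RMSE in the target's units."""
    return math.sqrt(mse)


def parse_elapsed(text: str) -> timedelta:
    """Parse an 'H:MM:SS.ffffff' string such as str(timedelta) produces."""
    hours, minutes, seconds = text.split(":")
    return timedelta(hours=int(hours), minutes=int(minutes), seconds=float(seconds))


# Final-epoch loss from the tf2 log (MSE), put on an approximate RMSE scale.
print(f"{mse_to_rmse(0.5300):.4f}")  # 0.7280

mine = parse_elapsed("0:03:51.420754")
xiaowei = parse_elapsed("0:02:47.654620")
print(f"{mine.total_seconds() / xiaowei.total_seconds():.2f}x")  # 1.38x
```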