PPPLDeepLearning / plasma-python

PPPL deep learning disruption prediction package
http://tigress-web.princeton.edu/~alexeys/docs-web/html/

Speed up multiple inference steps at the end of each epoch #61

Open: felker opened this issue 4 years ago

felker commented 4 years ago

Presently, at the end of every epoch, the trained weights are reloaded with the Keras Model.load_weights() method three separate times, once each to evaluate accuracy on the shots in the training, validation, and testing sets:

https://github.com/PPPLDeepLearning/plasma-python/blob/c82ba61e339882a5af10b1052edc0348e16119f4/plasma/models/mpi_runner.py#L932-L965

Depending on the size of the datasets (number of shots, pulse length, number of signals per shot), the network architecture, and the hardware, this process can take a significant amount of time. The overhead is especially noticeable when the per-epoch walltime is relatively short, e.g. due to small batch sizes or similar factors.
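To check how much of each evaluation pass is spent reloading weights versus running inference, one option is to wrap each phase in a small timer. This is a minimal sketch under stated assumptions: PhaseTimer and the phase names are hypothetical (not part of plasma-python), and the time.sleep() calls merely stand in for the real Keras load_weights() and prediction calls.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator


class PhaseTimer:
    """Hypothetical helper: accumulate wall-clock time per named phase."""

    def __init__(self) -> None:
        self.totals: Dict[str, float] = defaultdict(float)
        self.counts: Dict[str, int] = defaultdict(int)

    @contextmanager
    def phase(self, name: str) -> Iterator[None]:
        # Time the body of the `with` block and add it to this phase's total.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def report(self) -> str:
        # One line per phase, sorted by name.
        return "\n".join(
            f"{name}: {self.counts[name]} call(s), {self.totals[name]:.3f} s"
            for name in sorted(self.totals)
        )


# Usage: three evaluation passes (train, validation, test), each reloading weights.
timer = PhaseTimer()
for split in ("train", "validation", "test"):
    with timer.phase("load_weights"):
        time.sleep(0.01)  # stand-in for model.load_weights(...)
    with timer.phase("predict"):
        time.sleep(0.02)  # stand-in for the per-split prediction/ROC step
print(timer.report())
```

With the real calls substituted, the "load_weights" total shows how much walltime a single shared reload could save.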

For example, from a recent test with d3d_0D on Traverse using 4x V100 GPUs:

Finished training epoch 3.01 during this session (1.00 epochs passed) in 87.65 seconds
Finished training of epoch 6.01/1000
Begin evaluation of epoch 6.01/1000
[2] loading from epoch 6
[1] loading from epoch 6
[0] loading from epoch 6
[3] loading from epoch 6

128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 13s
896/894 [==============================] - 35s 39ms/step
[0] loading from epoch 6
[3] loading from epoch 6
[1] loading from epoch 6
[2] loading from epoch 6

128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 13s
896/894 [==============================] - 35s 39ms/step
epoch 6, val_roc_30 = 0.85346611872694 val_roc_70 = 0.8345022047574768 val_roc_200 = 0.7913309535951044 val_roc_500 = 0.6638869724330323 val_roc_1000 = 0.5480697123316435
[3] loading from epoch 6
[2] loading from epoch 6
[0] loading from epoch 6
[1] loading from epoch 6

128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 12s
896/894 [==============================] - 35s 39ms/step
epoch 6, test_roc_30 = 0.8400389140546622 test_roc_70 = 0.8236098866020126 test_roc_200 = 0.7792357036451524 test_roc_500 = 0.6798285349466453 test_roc_1000 = 0.5699692943787431

It seems straightforward to deduplicate the three 1:53 load times (visible as the initial ETA of each pass) by writing a new combined function to replace the two calls to mpi_make_predictions_and_evaluate_multiple_times() and the one call to mpi_make_predictions_and_evaluate().
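A minimal sketch of such a combined function, assuming the evaluation can be split into a weight-loading step and a per-split evaluation step. The names evaluate_all_splits, load_weights, and evaluate_split are hypothetical and do not correspond to existing plasma-python functions; in practice load_weights might wrap the Keras Model.load_weights() call and evaluate_split might wrap the existing prediction and ROC logic, whose real signatures may differ.

```python
from typing import Any, Callable, Dict, Mapping

# Hypothetical callable types; the real plasma-python signatures may differ.
LoadFn = Callable[[int], None]
EvalFn = Callable[[str, Any], Dict[str, float]]


def evaluate_all_splits(
    load_weights: LoadFn,
    evaluate_split: EvalFn,
    splits: Mapping[str, Any],
    epoch: int,
) -> Dict[str, Dict[str, float]]:
    """Load the weights for `epoch` once, then evaluate every split with them.

    Replaces one reload per split (3x for train/validation/test) with one reload.
    """
    load_weights(epoch)  # single reload shared by all splits
    return {name: evaluate_split(name, data) for name, data in splits.items()}


# Usage with stand-in callables that only count how often they are invoked.
calls = {"load": 0, "eval": 0}


def fake_load(epoch: int) -> None:
    calls["load"] += 1


def fake_eval(name: str, shots: Any) -> Dict[str, float]:
    calls["eval"] += 1
    return {"roc": 0.5, "n_shots": float(len(shots))}


results = evaluate_all_splits(
    fake_load,
    fake_eval,
    {"train": [1, 2, 3], "validation": [4, 5], "test": [6]},
    epoch=6,
)
print(calls)  # {'load': 1, 'eval': 3}
```

One design consideration, assuming each MPI rank runs this function: iterating over the splits in the same deterministic order on every rank keeps any collective communication inside evaluate_split matched across ranks.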