Closed mikowals closed 4 years ago
Great job on the profiling. I had forgotten that we'd done the same for other internal models to avoid the use of `scalarized()` and its performance consequences. This also simplifies the training loop code.
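For context, the slow pattern being avoided looks roughly like this (illustrative names, not the actual model code): calling `scalarized()` inside the batch loop forces a device-to-host transfer, and on XLA devices it also materializes the pending computation on every step.

```swift
import TensorFlow

// Anti-pattern (sketch): scalarizing every batch forces a host sync per step.
// var runningLoss: Float = 0
// for batch in dataset {
//     let loss = lossFunction(model(batch.data), batch.labels)
//     runningLoss += loss.scalarized()   // device-to-host copy each iteration
// }

// Preferred: accumulate as a Tensor on-device, scalarize once after the loop.
// var runningLoss = Tensor<Float>(0, on: device)
// for batch in dataset { runningLoss += loss }
// print(runningLoss.scalarized())       // single host transfer
```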
Would it be possible to clean the execution state of your notebook, removing the output from the cells and their execution order? I'd like for the notebook to start blank, and it would be preferable for the diff here to only be the code that was changed.
I removed the outputs and execution state. I did not realise those got saved as part of the notebook.
Thanks for bearing with me as I add silly mistakes while hurrying commits. I found your suggestions all very useful.
Keeping accumulated training statistics as tensors and only calling `scalarized()` for printing appears to speed up performance ~2x with XLA devices. The timings below were done on Colab free instances.

The changes I made are:
- `scalarized()` called outside batch loops
- `Statistics.init(on: device)` to avoid device mismatch errors
- `update()` method

There may be a simpler way to get the same impact in the training tutorial. Also, I have seen similar Statistics code that might be a more general solution.
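The changes listed above can be sketched as follows. This is a minimal illustration, not the PR's actual code; the field names and the exact `update()` signature are assumptions.

```swift
import TensorFlow

// Hypothetical accumulator that keeps running totals as tensors on one device,
// deferring the host transfer until the values are printed.
struct Statistics {
    var totalLoss: Tensor<Float>
    var correctCount: Tensor<Int32>
    var sampleCount: Int = 0

    // Create the accumulators directly on the target device so that adding
    // per-batch tensors does not trigger a device mismatch error.
    init(on device: Device) {
        totalLoss = Tensor(0, on: device)
        correctCount = Tensor(0, on: device)
    }

    // Accumulate on-device; no scalarized() inside the batch loop.
    mutating func update(loss: Tensor<Float>, correct: Tensor<Int32>, batchSize: Int) {
        totalLoss += loss
        correctCount += correct
        sampleCount += batchSize
    }
}

// After the batch loop, a single scalarized() per statistic pulls the
// values to the host for printing:
// print("loss: \(stats.totalLoss.scalarized() / Float(stats.sampleCount))")
```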