chrism0dwk / covid19uk

MIT License

Switch to tf.while_loop-based Multinomial after all. #8

Closed: csuter closed this 4 years ago

csuter commented 4 years ago

The for-loop-based approach, for reasons I don't yet understand, incurs XLA compilation times that scale poorly with the length of the inference period (at least linearly, possibly quadratically; I haven't measured). We are working on incorporating the Multinomial code I added here into TFP's Multinomial distribution, but until that is checked in (which should be soon), I wanted to get this in here.
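The Multinomial code itself is not shown in this thread. As a hedged illustration only, one common way to sample a Multinomial with an explicit loop is the conditional-binomial decomposition: walk over the categories, draw a Binomial for each one conditioned on the trials and probability mass still remaining, and give whatever is left to the last category. The sketch below is plain Python (not TensorFlow), and `multinomial_via_binomials` and `binomial` are hypothetical helpers; the loop state `(i, remaining_n, remaining_mass)` is the kind of carry a `tf.while_loop` version would thread through its body, whereas a Python `for` loop would instead be traced once per category.

```python
import random
from typing import List, Sequence


def binomial(rng: random.Random, n: int, p: float) -> int:
    """Draw Binomial(n, p) by summing n Bernoulli(p) trials (simple, O(n))."""
    if p <= 0.0:
        return 0
    if p >= 1.0:
        return n
    return sum(1 for _ in range(n) if rng.random() < p)


def multinomial_via_binomials(rng: random.Random, n: int,
                              probs: Sequence[float]) -> List[int]:
    """Hypothetical helper: sample Multinomial(n, probs) via conditional binomials.

    Assumes probs are non-negative and sum to 1. The loop state
    (i, remaining_n, remaining_mass) mirrors what a tf.while_loop would carry.
    """
    counts = [0] * len(probs)
    remaining_n = n
    remaining_mass = 1.0
    i = 0
    while i < len(probs) - 1 and remaining_n > 0:
        # P(category i | not one of categories 0..i-1), clamped for float error.
        p_cond = probs[i] / remaining_mass if remaining_mass > 0 else 0.0
        x = binomial(rng, remaining_n, min(max(p_cond, 0.0), 1.0))
        counts[i] = x
        remaining_n -= x
        remaining_mass -= probs[i]
        i += 1
    # Any trials not yet assigned belong to the last category.
    counts[-1] += remaining_n
    return counts


if __name__ == "__main__":
    rng = random.Random(0)
    print(multinomial_via_binomials(rng, 100, [0.2, 0.3, 0.5]))
```

Every draw conditions on what earlier categories consumed, so the counts always sum to `n`; the price is a sequential dependency across categories, which is why the choice of loop construct matters for compilation.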

We do take a hit on post-compilation iteration time, but I suspect we can improve on that. These are the numbers I'm seeing for an inference period of length 132 (sorry for the odd number...):

For-loop method:

- Run 1: 165.3 seconds
- Run 2: 1.796 seconds (~0.0136 s per iteration)

tf.while_loop method:

- Run 1: 18.224 seconds
- Run 2: 6.886 seconds (~0.052 s per iteration)

If the longer compile time with faster iterations is preferable, feel free not to merge this PR!
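To make the trade-off concrete, here is a back-of-the-envelope sketch using the timings above. It assumes (an interpretation, not stated in the thread) that Run 1 is a one-off compilation plus one run and that Run 2 is one post-compilation run of the same workload; the per-iteration figures quoted above match Run 2 divided by the 132-step inference period. The helper names are hypothetical.

```python
# Timings reported above (seconds) for an inference period of length 132.
T_STEPS = 132
FOR_LOOP = {"run1": 165.3, "run2": 1.796}
WHILE_LOOP = {"run1": 18.224, "run2": 6.886}


def per_step(t: dict) -> float:
    """Post-compilation time per step of the inference period."""
    return t["run2"] / T_STEPS


def one_time_overhead(t: dict) -> float:
    # Assumption: Run 1 = one-off compilation cost + one ordinary run.
    return t["run1"] - t["run2"]


def break_even_runs(slow_compile: dict, fast_compile: dict) -> int:
    """Smallest number of runs N for which slow_compile's total time
    (overhead + N * run2) is strictly lower than fast_compile's."""
    extra_overhead = one_time_overhead(slow_compile) - one_time_overhead(fast_compile)
    saved_per_run = fast_compile["run2"] - slow_compile["run2"]
    return int(extra_overhead // saved_per_run) + 1


if __name__ == "__main__":
    print(round(per_step(FOR_LOOP), 4), round(per_step(WHILE_LOOP), 3))
    print(break_even_runs(FOR_LOOP, WHILE_LOOP))
```

Under these assumptions the for-loop version pays off once a compiled program is reused for about 30 runs, so longer or repeated simulations favour it.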

chrism0dwk commented 4 years ago

Thanks Chris. I'm going to stick with the for-loop version for now; I can stomach the startup cost, as I'll most likely want to run the simulation longer in the future.

csuter commented 4 years ago

Sounds good! We have some more work ongoing that might improve relative performance further. Will keep you updated.