MathMarEcol / pdyer_aus_bio


Torch code could be precompiled to speed up execution #24

Closed: PhDyellow closed this issue 1 year ago

PhDyellow commented 1 year ago

This is a low-priority issue. Currently, my biggest data sets are processed within an hour, and other code, mostly clustering, is still the limiting factor. I can run the whole pipeline within 24-48 hours without any initial cache, so even a 10x speedup on the prediction and extrapolation code would not save any developer time; the code just runs overnight.

R torch has `jit_compile()`, which takes TorchScript source and returns compiled kernels.

https://blogs.rstudio.com/ai/posts/2021-08-10-jit-trace-module/#how-to-make-use-of-torch-jit-compilation

The post has info on how to do it with `jit_trace()` if I just want to reuse R code.
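
For reference, the two entry points look roughly like this (a minimal sketch following that post; `scripted_abs` and `fn` are throwaway placeholders, not pipeline code):

```r
library(torch)

# jit_compile(): pass TorchScript source, get compiled functions back.
comp <- jit_compile("
def scripted_abs(x):
    return torch.abs(x)
")
comp$scripted_abs(torch_tensor(c(-1, 2)))

# jit_trace(): run an existing R torch function on example input,
# recording the ops into a callable compiled function.
fn <- function(x) torch_abs(x)
traced <- jit_trace(fn, torch_tensor(c(-1, 2)))
traced(torch_tensor(c(-3, 4)))
```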

This would speed up execution because R would make one call to the GPU and get a Bhattacharyya distance vector tensor in return. Currently, R makes around a dozen or more calls to the GPU for each batch.

Basic idea:

Create a new target that compiles the kernel.
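
Assuming the pipeline is orchestrated with {targets}, the compile step could look roughly like this (hypothetical target names; I am also assuming a two-tensor signature and placeholder shapes for `bhattacharyya_dist_tensor`):

```r
library(targets)

list(
  # Hypothetical target: trace the kernel once, persist it with jit_save(),
  # and return the path so {targets} can track the file. Compiled torch
  # objects do not survive R's default serialization, hence format = "file".
  tar_target(
    bhatt_kernel_file,
    {
      traced <- torch::jit_trace(
        bhattacharyya_dist_tensor,  # existing pipeline function
        torch::torch_rand(8, 4),    # dummy batch; shapes are placeholders
        torch::torch_rand(8, 4)
      )
      torch::jit_save(traced, "bhatt_kernel.pt")
      "bhatt_kernel.pt"
    },
    format = "file"
  ),
  # Downstream targets reload the compiled kernel rather than re-tracing.
  # batch_p/batch_q stand in for hypothetical upstream batch targets.
  tar_target(
    batch_dists,
    torch::jit_load(bhatt_kernel_file)(batch_p, batch_q)
  )
)
```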

If using `jit_compile()`, rewrite the `bhattacharyya_dist_tensor` code as TorchScript. I have to figure out whether to pre-build the batches in R or use torch to build them; that means looking at torch's code for indexing with integer vectors, which is slower than slicing but more flexible (e.g. it can access a subset multiple times). Once the `jit_compile()` object is built, call it on each batch, possibly even inside `bhattacharyya_dist_tensor`.
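
A minimal sketch of the TorchScript route, using a simplified discrete Bhattacharyya distance as a stand-in for the real `bhattacharyya_dist_tensor` (the actual kernel may differ):

```r
library(torch)

# TorchScript source is Python-like and 0-indexed. One compiled call
# computes the whole per-row distance vector, instead of one GPU call
# per intermediate op.
compiled <- jit_compile("
def bhatt_dist(p, q):
    # Bhattacharyya coefficient per row, then distance = -log(coefficient)
    coef = torch.sum(torch.sqrt(p * q), dim=1)
    return -torch.log(coef)
")

# Dummy batch: rows are probability distributions.
p <- nnf_softmax(torch_rand(5, 3), dim = 2)
q <- nnf_softmax(torch_rand(5, 3), dim = 2)
compiled$bhatt_dist(p, q)
```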

With `jit_trace()`, just run `bhattacharyya_dist_tensor` through `jit_trace()` with dummy data and store the result. `jit_trace()` will return a callable function; use it with each batch. I need to figure out if it can handle variable-sized batches.
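
A sketch of that path, again with a simplified stand-in body for `bhattacharyya_dist_tensor` (the real function's signature may differ):

```r
library(torch)

# Stand-in R implementation; R torch dims are 1-indexed, so dim = 2
# reduces across columns, giving one distance per row.
bhatt_r <- function(p, q) {
  -torch_log(torch_sum(torch_sqrt(p * q), dim = 2))
}

dummy_p <- nnf_softmax(torch_rand(10, 3), dim = 2)
dummy_q <- nnf_softmax(torch_rand(10, 3), dim = 2)

# jit_trace() executes bhatt_r once on the dummy batch and records the
# ops, returning a callable compiled function.
traced <- jit_trace(bhatt_r, dummy_p, dummy_q)

# Traces fix control flow, not necessarily shapes: a purely elementwise
# trace plus reductions like this one should accept other batch sizes,
# but that needs to be verified against the real kernel.
traced(nnf_softmax(torch_rand(25, 3), dim = 2),
       nnf_softmax(torch_rand(25, 3), dim = 2))
```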