Computational-Morphogenomics-Group / MarkerMap

Marker selection, supervised and unsupervised
MIT License

Optimize sample_subset with Pytorch JIT #9

Open WilsonGregory opened 2 years ago

WilsonGregory commented 2 years ago

The sample_subset function is a bit slow. I think the for loops over the k samples really slow it down. We could let the PyTorch JIT optimize it for us.

You can see this slow behavior when you set k = 50 vs. 250, for example. There seem to be other ways to unroll the loops too, but JIT seems the easiest to try out casually. Other modules (like those shipped with PyTorch by default) avoid this overhead because they are written in C++, but rewriting in C++ seems like overkill for us.
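To make the bottleneck concrete, here is a minimal sketch of the kind of loop involved. This is not the actual sample_subset from MarkerMap, just a hypothetical relaxed k-subset sampler with the same shape: k sequential Gumbel-softmax draws, where each iteration pays Python-level loop and kernel-launch overhead, so runtime grows roughly linearly with k.

```python
import torch

def sample_subset_loop(logits: torch.Tensor, k: int, tau: float = 1.0) -> torch.Tensor:
    # Hypothetical stand-in for sample_subset: draw k relaxed one-hot
    # vectors sequentially, softly masking out mass already selected.
    scores = logits.clone()
    relaxed = torch.zeros_like(logits)
    for _ in range(k):  # this Python-level loop is the suspected bottleneck
        gumbel = -torch.log(-torch.log(torch.rand_like(scores)))
        onehot_relaxed = torch.softmax((scores + gumbel) / tau, dim=-1)
        relaxed = relaxed + onehot_relaxed
        # down-weight features that already received selection mass
        scores = scores + torch.log1p(-onehot_relaxed.clamp(max=1 - 1e-6))
    return relaxed
```

Each softmax sums to 1, so the returned relaxation sums to k; the per-iteration overhead (not the math) is what JIT compilation might shave off.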

Similar case: https://discuss.pytorch.org/t/unroll-for-loops-in-forward-pass-for-gpu/62597

Why I think JIT might help: https://spell.ml/blog/pytorch-jit-YBmYuBEAACgAiv71

Is it as simple as tagging the sample_subset function or applying a transform to it? https://pytorch.org/docs/stable/generated/torch.jit.script.html

If it is that simple, we could have an argument like --use-jit = True / False to enable compiling sample_subset. I'm not sure how to conditionally tag with @torch.jit.script, but torch.jit.script is also exposed as a plain function, which should make it easy.
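Since torch.jit.script works both as a decorator and as an ordinary function, the conditional tagging could look like the sketch below. The function body here is a hypothetical stand-in, not the real sample_subset, and use_jit stands in for the proposed --use-jit flag.

```python
import torch

def sample_subset(w: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical stand-in body; the real sample_subset lives in MarkerMap.
    # Type annotations on non-tensor args (k: int) are required for scripting.
    out = torch.zeros_like(w)
    for _ in range(k):
        out = out + torch.softmax(w, dim=-1)
    return out

use_jit = True  # would come from the proposed --use-jit argument

# torch.jit.script is exposed as a function, so we can opt in at runtime
# instead of hard-coding the @torch.jit.script decorator.
if use_jit:
    sample_subset = torch.jit.script(sample_subset)
```

Callers are unaffected either way: the scripted version has the same signature, so the flag only toggles whether TorchScript compiles (and can unroll/optimize) the loop.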

Originally posted by @beelze-b in https://github.com/Computational-Morphogenomics-Group/MarkerMap/issues/6#issuecomment-1045216624