ramsey-devs / ramsey

Probabilistic deep learning using JAX
https://ramsey.rtfd.io
Apache License 2.0

Addition of SNP (sequential neural processes) #46

Open · ojss opened this issue 1 month ago

ojss commented 1 month ago

Hi @dirmeier, thanks for making this amazing codebase! I am currently working on a sequential neural process implementation that can hopefully be added to ramsey at some point. However, I am quite new to Flax/JAX and am running into some issues with data shapes, so I am hoping for some insight. Specifically, at each time point in a sequence the number of context points can vary and can also be zero. I am trying to use masks, but that is turning out to be somewhat messy. Do you have suggestions on how this could be achieved?

dirmeier commented 1 month ago

Hi @ojss, that sounds fantastic! Having this in ramsey would be amazing.

With respect to JAX, I agree that this is not straightforward and is a bit complicated. We also ended up using masks for this in some models (which are not yet implemented). But I think you could start with a prototype that implements a working model, and then we'll see whether it can be sped up, etc.

Cheers, Simon
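For readers hitting the same problem, here is a minimal sketch of the padding-and-masking approach mentioned in this thread. It assumes each time step's context points arrive as a NumPy array of shape [n_t, dim], where n_t may be zero. The helper names `pad_context`, `masked_mean`, and `aggregate_over_time` are hypothetical and not part of ramsey's API. A plain masked mean stands in for whatever permutation-invariant aggregator the actual model uses.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_context(points_per_step, max_context, dim):
    """Pad a list of variable-length [n_t, dim] arrays into one dense tensor.

    Hypothetical helper, run on the host before jit. Returns
    (padded [T, max_context, dim], mask [T, max_context]).
    """
    n_steps = len(points_per_step)
    padded = np.zeros((n_steps, max_context, dim), dtype=np.float32)
    mask = np.zeros((n_steps, max_context), dtype=bool)
    for t, pts in enumerate(points_per_step):
        n = pts.shape[0]          # may be 0 for an empty context set
        padded[t, :n] = pts       # real points first, zeros afterwards
        mask[t, :n] = True        # True marks real (non-padding) points
    return jnp.asarray(padded), jnp.asarray(mask)


def masked_mean(values, mask):
    """Average the rows of `values` [max_context, dim] where `mask` is True.

    Returns a [dim] vector, or all zeros if the context set is empty.
    """
    weights = mask.astype(values.dtype)[:, None]   # [max_context, 1]
    total = jnp.sum(values * weights, axis=0)       # padding contributes 0
    count = jnp.sum(weights)                        # number of real points
    # Guard against 0 / 0 (NaN) when a time step has no context points.
    return total / jnp.maximum(count, 1.0)


# Aggregate all time steps at once by vmapping over the leading time axis.
aggregate_over_time = jax.jit(jax.vmap(masked_mean))

# Usage: three time steps with 2, 0 and 1 context points respectively.
steps = [
    np.array([[1.0, 2.0], [3.0, 4.0]]),
    np.zeros((0, 2)),
    np.array([[5.0, 6.0]]),
]
x, m = pad_context(steps, max_context=3, dim=2)
r = aggregate_over_time(x, m)   # shape [3, 2]
```

Padding every time step to a fixed `max_context` keeps array shapes static, so `jax.jit` compiles once instead of retracing for every new context size. The mask zeroes out padding, and the `jnp.maximum(count, 1.0)` guard returns a zero vector rather than NaN for empty context sets. A downstream model may also want the per-step count (for example `m.sum(axis=-1)`) so it can distinguish "no context" from "context that happens to average to zero".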