ojss opened this issue 1 month ago
Hi @ojss, that sounds fantastic! Having this in ramsey would be amazing.
With respect to JAX, I agree that this is not straightforward and a bit complicated. We also ended up using masks for this in some models (which are not yet implemented). But I think you could just start with a prototype that implements a working model, and then we'll see whether it can be sped up, etc.
Cheers, Simon
Hi @dirmeier, thanks for making this amazing codebase! I am currently working on a sequential neural process implementation that can hopefully be added to ramsey at some point. But I am quite new to Flax/JAX and am running into some issues with regard to data shapes, so I am hoping for some insight. Specifically, at each time point in a sequence the number of context points can vary and can even be zero. I am trying to use masks, but that is turning out to be somewhat messy. Would you have suggestions on how this could be achieved?
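For reference, here is a minimal sketch of the padding-plus-mask approach mentioned in this thread. It assumes each time step's context set arrives as a NumPy array of shape `(n_t, D)` with `n_t` possibly 0. The sets are padded to a fixed `N_MAX` so every time step has the same shape, which lets `jax.jit` and `jax.lax.scan` work over the sequence. A masked mean then aggregates only the valid points and returns zeros (not NaN) for empty steps. The names `pad_context_sets`, `masked_mean`, `summarize_sequence`, the one-layer encoder, and the "keep the previous state when there is no context" recurrence are all hypothetical placeholders. They are not ramsey's API and not the actual sequential neural process transition.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_context_sets(context_sets, max_points, dim):
    """Pad per-time-step context arrays of shape (n_t, dim), n_t may be 0,
    into a (T, max_points, dim) array plus a (T, max_points) boolean mask."""
    T = len(context_sets)
    padded = np.zeros((T, max_points, dim), dtype=np.float32)
    mask = np.zeros((T, max_points), dtype=bool)
    for t, pts in enumerate(context_sets):
        n = pts.shape[0]
        padded[t, :n] = pts  # real points first, zeros afterwards
        mask[t, :n] = True   # True marks a valid (non-padding) point
    return jnp.asarray(padded), jnp.asarray(mask)


def masked_mean(values, mask):
    """Mean over the point axis (-2) using only points where mask is True.
    Steps with no valid points yield zeros instead of NaN."""
    weights = mask[..., None].astype(values.dtype)
    total = jnp.sum(values * weights, axis=-2)
    count = jnp.sum(weights, axis=-2)
    return total / jnp.maximum(count, 1.0)  # avoid division by zero


def encode(params, xy):
    # Hypothetical per-point encoder: one linear layer followed by ReLU.
    return jax.nn.relu(xy @ params["w"] + params["b"])


@jax.jit
def summarize_sequence(params, padded, mask):
    def step(carry, inputs):
        ctx, m = inputs                          # (N_MAX, D), (N_MAX,)
        r = masked_mean(encode(params, ctx), m)  # (H,) context summary
        has_ctx = jnp.any(m)
        # Hypothetical recurrence: blend in r only if this step has context.
        new_carry = jnp.where(has_ctx, 0.5 * carry + 0.5 * r, carry)
        return new_carry, new_carry

    init = jnp.zeros(params["b"].shape)
    _, states = jax.lax.scan(step, init, (padded, mask))
    return states                                # (T, H)


# Usage: three time steps with 2, 0 and 5 context points.
D, H, N_MAX = 3, 8, 5
rng = np.random.default_rng(0)
params = {"w": jax.random.normal(jax.random.PRNGKey(0), (D, H)), "b": jnp.zeros(H)}
contexts = [rng.normal(size=(2, D)), np.zeros((0, D)), rng.normal(size=(5, D))]
padded, mask = pad_context_sets(contexts, N_MAX, D)
states = summarize_sequence(params, padded, mask)
print(states.shape)  # (3, 8)
```

Because the shapes are fixed after padding, the whole sequence compiles once. Empty time steps are handled by `jnp.where` inside the scan rather than by Python control flow, which would not trace under `jit`.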