Arun-Niranjan opened this issue 3 years ago.
I looked into Awkward Array and watched that tutorial video. My first impression is that this would be a good direction to explore. The ability to move between semantic-style labels (the 'dot' notation, e.g. `my_array.field1`) and plain array-style NumPy slicing (`my_array[:,...,some_fancy_index]`) is quite attractive, as are its extended NumPy-style methods for broadcasting and reduction. I think this could be particularly useful when we need to apply an identical set of mathematical transformations across all the posterior marginals of a given multi-hidden-state-factor belief, or when we have to perform Bayesian model averaging across policies within each (differently-dimensioned, policy-conditioned) hidden state factor.
So the TL;DR is yes, I think this is worth investigating further; from my end, you could go ahead with a PR. Let's hear what @alec-tschantz also has to say.
Great, I'll have a go at this either this weekend or next, depending on how work goes!
Also apologies for not making a PR for this yet - it is on my todo list for this week!
I thought it'd be good to get a discussion going on how we represent distributions in pymdp.
Currently we have the Categorical and Dirichlet classes, which I believe we are trying to move away from: they're tricky to use, and every algorithm that calls them has to constantly check whether its inputs are instances of these classes or not.
For now, it seems we're moving towards using NumPy arrays of NumPy arrays (object arrays) to shift data around. IIUC, this is necessary because we often have varying substructures depending on the number of factors involved (e.g. we have multi-dimensional tensors representing an agent's priors, policy choices, etc.).
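For comparison, a minimal sketch of that object-array representation in plain NumPy. The helper `to_obj_array` is hypothetical (pymdp's own utilities may differ), and the prior values are made up. It shows that every per-factor operation needs an explicit Python loop.

```python
import numpy as np

def to_obj_array(arrays):
    """Hypothetical helper: pack arrays of differing shapes into a 1-D object array.

    Filling a pre-allocated object array element by element guarantees one
    entry per factor, even when every input happens to have the same shape.
    """
    out = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        out[i] = np.asarray(arr, dtype=float)
    return out

# Hypothetical prior over two hidden-state factors with 3 and 2 states.
prior = to_obj_array([[2.0, 1.0, 1.0], [3.0, 1.0]])

# Any per-factor operation needs an explicit Python loop over the entries.
normalized = to_obj_array([p / p.sum() for p in prior])
```

The pre-allocation matters: `np.array([np.ones(2), np.ones(2)], dtype=object)` silently builds a single (2, 2) array rather than two separate factor entries, which is one of the sharp edges of this representation.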
I previously considered whether we should write our own dataclass for representing what we need, but I think it's likely to end up the same way as the Dirichlet/Categorical classes. It would be a step forward, but still vulnerable to the same problems of being fragile to refactor and improve upon.
I suggest we use the third-party Awkward Array library to solve this problem for us. It seems to have a stable API, and we should be able to use any existing NumPy methods with those arrays as long as they are not jagged (which we have to sort out anyway in the current implementation).
I think using this library will provide the following benefits:
There is always a risk in introducing third-party dependencies, but I am of the opinion that if it's good enough for particle physics (tracks and decay events) at the scale CERN is dealing with, it is almost certainly good enough for us.
@conorheins @alec-tschantz what do you think? I could have a go at introducing it in a PR if you think this idea is worth investigating further.
(tutorial video here is nice: https://www.youtube.com/watch?v=WlnUF3LRBj4)