jax-ml / oryx

Oryx is a library for probabilistic programming and deep learning built on top of Jax.
https://tensorflow.org/probability/oryx
Apache License 2.0

Any insight into the dev model for `oryx`? #38

Status: Open · femtomc opened this issue 1 year ago

femtomc commented 1 year ago

Hi all!

Will Oryx continue to be actively maintained? Are there maintainers who are hoping to continue working on the package?

sharadmv commented 1 year ago

Hey! Oryx is being maintained to work against JAX at HEAD, but I'm not working on any new features (I work full time on JAX and jax-triton now).

femtomc commented 1 year ago

@sharadmv Thanks for the info!

Just to be transparent about motivations, I'm building a system based on Gen which uses JAX.

From Oryx's core, I'd like to use the `inverse` and `ildj` (inverse log-det-Jacobian) transformations to support an "exact logpdf" language whose model objects act like distributions. (A bit more about Gen: there is no single modeling language. Instead, there's a collection of modeling languages whose objects implement an abstract interface. Distributions are one such object, and I'm considering a language for "distributions + ILDJ-compatible functions" as another implementor of that interface.)
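As a rough illustration of what `inverse` and `ildj` buy you, here is a minimal pure-JAX sketch of the change-of-variables rule behind an "exact logpdf": for a bijection $y = f(x)$ with base density $p_X$, $\log p_Y(y) = \log p_X(f^{-1}(y)) + \mathrm{ILDJ}(y)$, where $\mathrm{ILDJ}(y) = \log\lvert\det J_{f^{-1}}(y)\rvert$. It does not use Oryx: the inverse is written by hand (an assumption for the sketch; Oryx derives it automatically), and the bijection $f(x) = e^x + 2$ with a standard-normal base is an illustrative choice.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def forward(x):
    # Illustrative bijection: y = exp(x) + 2 (maps R onto (2, inf)).
    return jnp.exp(x) + 2.0


def inverse(y):
    # Hand-written inverse of `forward`; Oryx's `inverse` would derive this.
    return jnp.log(y - 2.0)


def ildj(y):
    # Inverse log-det-Jacobian of a scalar bijection: log|d inverse(y) / dy|.
    return jnp.log(jnp.abs(jax.grad(inverse)(y)))


def transformed_logpdf(y):
    # Change of variables with a standard-normal base distribution.
    return norm.logpdf(inverse(y)) + ildj(y)


print(inverse(4.0))  # ≈ 0.6931 (log 2)
print(ildj(4.0))     # ≈ -0.6931 (-log 2)
```

The appeal of doing this with Oryx instead is that the inverse and ILDJ come from propagating per-primitive rules through the traced program, so nobody has to write `inverse` by hand.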

The implementation of these transformations in Oryx seems well designed for this task, so I'd like to use Oryx (or, at the very least, its conceptual content) for it.

I've also been thinking about Oryx's maintenance/dev model, since I'm considering it as a dependency. Depending on how Oryx will be maintained:

A few other things I've been thinking about in Oryx proper:

Cool work! I'll admit that when I first looked at Oryx, I totally misjudged its conceptual content; only recently did I really appreciate the value proposition and design of the language.

Thanks for any comments.

sharadmv commented 1 year ago

> Just to be transparent about motivations, I'm building a system based on Gen which uses JAX.

Sounds awesome! Do you have a repo/doc I could read to learn more?

> but potentially the nn and optimizer module might be useful, if I'm understanding their value prop - see below

The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx nn to implement RealNVP and MAF). However, the nn library is a bit more opinionated than it needs to be. I'd recommend using harvest directly to build your own mini state-management library.

Here's an example of one:

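```python
# Mini state library
import dataclasses
import functools
from typing import Any

import jax
import oryx
from oryx.core import ppl

# `reap` collects every value tagged as a state variable;
# `plant` injects values back in for those same tags.
collect = functools.partial(oryx.core.reap, tag=oryx.core.state.VARIABLE)
inject = functools.partial(oryx.core.plant, tag=oryx.core.state.VARIABLE)

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Module:
  params: Any  # reaped variables (pytree children)
  apply: Any   # callable (params, *args) -> output (static aux data)

  def __call__(self, *args, **kwargs):
    return self.apply(self.params, *args, **kwargs)

  def tree_flatten(self):
    return (self.params,), (self.apply,)

  @classmethod
  def tree_unflatten(cls, data, xs):
    return cls(xs[0], data[0])

# Let `ppl.log_prob` see through a Module to the program it wraps.
@ppl.log_prob.register(Module)
def module_log_prob(module):
  return ppl.log_prob(lambda *args: module.apply(module.params, *args))

def init(f):
  # Run `f` once to collect its variables, then return a Module whose
  # `apply` re-runs `f` with those variables planted back in.
  def wrapped(state_key, *args):
    params = collect(f)(state_key, *args)
    return Module(params, inject(functools.partial(f, state_key)))
  return wrapped
```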

The Module's `apply` could be an invertible function (like a normalizing flow) or an Oryx probabilistic program that we can call `log_prob` on.
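For a concrete picture of a parameterized, invertible function that a Module's `apply` could hold, here is a minimal pure-JAX sketch of a RealNVP-style affine coupling layer. It does not use Oryx; the parameter layout (one linear map producing a log-scale and a shift) and all function names are hypothetical choices for illustration. The ILDJ is computed analytically from the triangular Jacobian.

```python
import jax
import jax.numpy as jnp


def init_params(key, dim):
    # Hypothetical parameters: a linear map from the first half of the input
    # to a log-scale and a shift for the second half.
    half = dim // 2
    w = 0.1 * jax.random.normal(key, (half, 2 * (dim - half)))
    b = jnp.zeros(2 * (dim - half))
    return {"w": w, "b": b}


def _log_scale_and_shift(params, x1):
    # tanh bounds the log-scale so the layer stays well-conditioned.
    raw_log_s, t = jnp.split(x1 @ params["w"] + params["b"], 2)
    return jnp.tanh(raw_log_s), t


def forward(params, x):
    # y1 = x1;  y2 = x2 * exp(log_s(x1)) + t(x1)
    half = x.shape[0] // 2
    x1, x2 = x[:half], x[half:]
    log_s, t = _log_scale_and_shift(params, x1)
    return jnp.concatenate([x1, x2 * jnp.exp(log_s) + t])


def inverse(params, y):
    # y1 = x1 determines log_s and t, so x2 = (y2 - t) * exp(-log_s).
    half = y.shape[0] // 2
    y1, y2 = y[:half], y[half:]
    log_s, t = _log_scale_and_shift(params, y1)
    return jnp.concatenate([y1, (y2 - t) * jnp.exp(-log_s)])


def ildj(params, y):
    # The inverse's Jacobian is block-triangular with diagonal exp(-log_s),
    # so log|det| = -sum(log_s).
    half = y.shape[0] // 2
    log_s, _ = _log_scale_and_shift(params, y[:half])
    return -jnp.sum(log_s)
```

Because the parameters live in a plain dict pytree, a layer like this can be carried around and transformed with `jax.jit`/`jax.grad` like any other pytree.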

> You would have a better understanding than me about this -- do you feel like the current state of Oryx is "ruleset complete" for ILDJ/inverses?

The rules are probably not as complete as I'd like them to be, mainly for lack of time/demand. I'm happy to accept new rules in a PR, though!
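To make the idea of "rules" concrete: an inverse/ILDJ transformation composes a small rule per primitive, so extending coverage means adding entries. The sketch below is a toy, hypothetical rule table over a list of named steps; it is not Oryx's internal representation (which works on traced programs). It inverts a program step by step and sums the per-step ILDJ terms.

```python
import jax.numpy as jnp

# Hypothetical rule table (NOT Oryx's internals): primitive name ->
# (forward, inverse, ildj), where ildj(y, c) = log|d inverse(y, c) / dy|.
RULES = {
    "exp": (
        lambda x, c: jnp.exp(x),
        lambda y, c: jnp.log(y),
        lambda y, c: -jnp.log(y),
    ),
    "add": (
        lambda x, c: x + c,
        lambda y, c: y - c,
        lambda y, c: jnp.zeros_like(y),
    ),
    "mul": (
        lambda x, c: x * c,
        lambda y, c: y / c,
        lambda y, c: -jnp.log(jnp.abs(c)) * jnp.ones_like(y),
    ),
}


def run(program, x):
    # A "program" is a list of (primitive_name, constant) steps.
    for name, c in program:
        forward, _, _ = RULES[name]
        x = forward(x, c)
    return x


def inverse_and_ildj(program, y):
    # Walk the steps backwards: add each step's ILDJ at its output,
    # then invert that step.
    total = jnp.zeros_like(y)
    for name, c in reversed(program):
        _, inverse, ildj = RULES[name]
        total = total + ildj(y, c)
        y = inverse(y, c)
    return y, total


program = [("exp", None), ("add", 2.0)]  # f(x) = exp(x) + 2
x, total_ildj = inverse_and_ildj(program, 4.0)  # x ≈ log 2, total_ildj ≈ -log 2
```

A step with no entry in the table simply cannot be inverted; control-flow primitives such as `scan` would each need a (much more involved) rule of their own.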

Control flow is a big hole in the rules right now: inverting something like scan is possible, but highly nontrivial. However, doing so would enable things like HMMs expressed via scan! That was a long-term goal with Oryx for me, but I have never gotten around to implementing it.
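For reference, an HMM "expressed via scan" might look roughly like the plain-JAX sketch below: the latent state is threaded through `jax.lax.scan` as the carry, and each step samples a transition and an emission. This is not an Oryx program; the two-state transition matrix, Gaussian emissions, and the function name are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-state HMM: transition matrix and per-state emission means.
TRANS = jnp.array([[0.9, 0.1],
                   [0.2, 0.8]])
MEANS = jnp.array([-1.0, 2.0])


def sample_hmm(key, z0, num_steps):
    # Each step: z_t ~ Categorical(TRANS[z_{t-1}]), x_t ~ Normal(MEANS[z_t], 1).
    def step(z_prev, step_key):
        k_z, k_x = jax.random.split(step_key)
        z = jax.random.categorical(k_z, jnp.log(TRANS[z_prev])).astype(jnp.int32)
        x = MEANS[z] + jax.random.normal(k_x)
        return z, (z, x)

    keys = jax.random.split(key, num_steps)
    # Cast the initial carry so its dtype matches the carry `step` returns.
    _, (zs, xs) = jax.lax.scan(step, jnp.asarray(z0, jnp.int32), keys)
    return zs, xs
```

Inverting a program like this with respect to `xs` has to reason about the whole loop body at once, which is why a rule for `scan` is much harder than a rule for `exp` or `add`.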

femtomc commented 1 year ago

> Sounds awesome! Do you have a repo/doc I could read to learn more?

@sharadmv I've sent you a private email about this (private because we're still working closed-source).

> The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx nn to implement RealNVP and MAF). However, the nn library is a bit more opinionated than it needs to be. I'd recommend using harvest directly to build your own mini state-management library.

`harvest` is super neat. I previously wrote a restricted version of harvest that didn't handle higher-order primitives, and in my modeling code I'm still largely not concerned with them: a level of modularity in the model design lets me use higher-order models to support things like vmap or scan.

> Control flow is a big hole in the rules right now -- inverting something like scan is possible, but highly nontrivial. However, doing so would enable things like HMMs expressed via scan! That was a long-term goal w/ Oryx for me but I have never gotten around to implementing it.

Right, this is pretty interesting. Because Gen doesn't place any restrictions on the return value function $f$, there's a straightforward way to support things like logpdf for the internal random choices of models that use scan (Gen doesn't assume that the return value function is a transformation whose output you wish to constrain). The way you gain access to scan is to use one of the higher-order models above (which also implement Gen's interface).
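To illustrate the distinction: scoring the internal random choices of a scan-based model only requires summing per-step log-densities of those choices, with no inversion of the return value. Below is a minimal plain-JAX sketch for a hypothetical two-state HMM with Gaussian emissions, where both the latent states `zs` and the observations `xs` are treated as recorded choices; the parameters and function name are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical HMM parameters: two-state transitions, unit-variance Gaussian emissions.
TRANS = jnp.array([[0.9, 0.1],
                   [0.2, 0.8]])
MEANS = jnp.array([-1.0, 2.0])


def hmm_choices_logpdf(z0, zs, xs):
    # log p(zs, xs | z0) = sum_t [log TRANS[z_{t-1}, z_t] + log N(x_t; MEANS[z_t], 1)]
    def step(carry, zx):
        z_prev, total = carry
        z, x = zx
        total = total + jnp.log(TRANS[z_prev, z]) + norm.logpdf(x, MEANS[z], 1.0)
        return (z, total), None

    init = (jnp.asarray(z0, zs.dtype), jnp.zeros(()))
    (_, total), _ = jax.lax.scan(step, init, (zs, xs))
    return total
```

Nothing here needs to invert the loop: every choice is observed, so the score is just an accumulation carried through the `scan`.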

I'm curious what happens if I take an Oryx model without control flow, which supports logpdf, and shove it into one of the higher-order models above.

(Re: when I comment on Gen + Oryx, I'm thinking of Oryx as providing a DSL for defining objects that support sampling and exact logpdf evaluation. If an object supports these two interfaces, Gen's interface can be defined on it automatically.)
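The last point can be sketched as follows: given any object with `sample` and an exact `logpdf`, a Gen-style interface can be derived mechanically. The function names `simulate`, `assess`, and `generate` are loosely modeled on Gen's generative function interface, but the container classes, signatures, and the log-normal example are hypothetical; for a distribution-like object, the whole "trace" is just the sampled value.

```python
import dataclasses
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


@dataclasses.dataclass(frozen=True)
class ExactDensity:
    # Hypothetical container for any object exposing these two callables.
    sample: Callable[[Any], Any]  # key -> value
    logpdf: Callable[[Any], Any]  # value -> exact log density


@dataclasses.dataclass(frozen=True)
class Trace:
    value: Any
    score: Any  # log density of the recorded choice


def simulate(gen_fn, key):
    # Sample without constraints and record the score.
    value = gen_fn.sample(key)
    return Trace(value, gen_fn.logpdf(value))


def assess(gen_fn, value):
    # Score a fully constrained choice.
    return gen_fn.logpdf(value)


def generate(gen_fn, key, constraint: Optional[Any] = None):
    # Returns (trace, log weight). Using the object itself as the proposal,
    # the weight is the score when constrained and 0 when unconstrained.
    if constraint is None:
        return simulate(gen_fn, key), jnp.zeros(())
    score = gen_fn.logpdf(constraint)
    return Trace(constraint, score), score


# Example: a log-normal, i.e. exp applied to a standard normal, whose logpdf
# comes from the change-of-variables rule (ILDJ of log is -log y).
lognormal = ExactDensity(
    sample=lambda key: jnp.exp(jax.random.normal(key)),
    logpdf=lambda y: norm.logpdf(jnp.log(y)) - jnp.log(y),
)
```

Only `sample` and `logpdf` are ever called, which is the sense in which these two capabilities are enough to derive the rest of the interface for such objects.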