Open femtomc opened 1 year ago
Hey! Oryx is being maintained to work against JAX at HEAD, but I'm not working on any new features (I work full time on JAX and jax-triton now).
@sharadmv Thanks for the info!
Just to be transparent about motivations, I'm building a system based on Gen which uses JAX.
From Oryx's core, I'd like to use the `inverse` and `ildj` transforms to support an "exact logpdf" language, whose model objects act like distributions. (A bit more about Gen: there is no single modeling language. There's a collection of languages for models, whose objects implement an abstract interface. Distributions are one such object, and I'm considering a language for "distributions + ILDJ compat functions" as another implementor of the interface.)
The implementation of these transformations in Oryx seems well designed for this task -- so I'd like to use Oryx (or, at the very least, the conceptual content of Oryx) for it.
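For readers unfamiliar with the term, the "exact logpdf" idea rests on the change-of-variables formula: log p_Y(y) = log p_X(f⁻¹(y)) + ILDJ(y). The following is a minimal sketch in plain JAX, not Oryx's implementation. The function names (`forward`, `inverse`, `ildj`, `exact_logpdf`) are hypothetical, and the inverse is written by hand here, whereas Oryx would derive it from the forward program.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def forward(x):
    # Bijection applied to a standard normal base variable: y = exp(x).
    return jnp.exp(x)


def inverse(y):
    # Hand-written inverse of `forward` (assumption: Oryx derives this itself).
    return jnp.log(y)


def ildj(y):
    # Inverse log-det Jacobian for a scalar map: log |d inverse(y) / dy|.
    return jnp.log(jnp.abs(jax.grad(inverse)(y)))


def exact_logpdf(y):
    # Change of variables: base density at the preimage plus the ILDJ term.
    return norm.logpdf(inverse(y)) + ildj(y)
```

With a standard normal base and `exp` as the bijection, `exact_logpdf` should agree with the log-normal density, which gives a convenient sanity check.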
I've also been considering the maintenance/dev model of Oryx, as I was considering Oryx as a dependency. Depending on information about the maintenance of Oryx:
`inverse` and `ildj` -- I've used the `propagate` interpreter for another language previously, so I already have a modified version of that in my codebase -- but I really wanted to wait to do any forking/mangling until I could converse with you. There are parts of the library which I would likely make great use of, and other parts which I don't think I would use (e.g. I think I wouldn't actually use any of the inference modules, but potentially the `nn` and `optimizer` modules might be useful, if I'm understanding their value prop - see below).
A few other things I've been thinking about in Oryx proper:
`nn` module -- and the intent behind it. One fascinating value proposition would be supporting `nn` parametrized functions which operate on random variables -- which are also compat with `ILDJ`. I'm sort of guessing that's what was intended. In Gen's Julia implementation, we constructed an "invertible transformation" distributions DSL - but it's less expressive than Oryx - and I don't think we seriously considered neural networks + ILDJ.
Cool work! I will admit, when I first looked at Oryx, I totally misjudged the conceptual content. Only recently did I really appreciate the language value proposition and design.
Thanks for any comments.
> Just to be transparent about motivations, I'm building a system based on Gen which uses JAX.
Sounds awesome! Do you have a repo/doc I could read to learn more?
> but potentially the `nn` and `optimizer` module might be useful, if I'm understanding their value prop - see below
The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx `nn` to implement RealNVP and MAF). However, the `nn` library is a bit more opinionated than it needs to be. I'd recommend using `harvest` directly to build your own mini state-management library.
Here's an example one:
```python
import dataclasses
import functools
from typing import Any

import jax
import oryx

# Mini state library
collect = functools.partial(oryx.core.reap, tag=oryx.core.state.VARIABLE)
inject = functools.partial(oryx.core.plant, tag=oryx.core.state.VARIABLE)

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Module:
  params: Any
  apply: Any

  def __call__(self, *args, **kwargs):
    return self.apply(self.params, *args, **kwargs)

  def tree_flatten(self):
    # `params` are the pytree children; `apply` is static auxiliary data.
    return (self.params,), (self.apply,)

  @classmethod
  def tree_unflatten(cls, data, xs):
    return cls(xs[0], data[0])

@oryx.core.ppl.log_prob.register(Module)
def module_log_prob(module):
  return oryx.core.ppl.log_prob(
      lambda *args: module.apply(module.params, *args))

def init(f):
  def wrapped(state_key, *args):
    params = collect(f)(state_key, *args)
    return Module(params, inject(functools.partial(f, state_key)))
  return wrapped
```
The `Module`'s `apply` could be an invertible function (like a normalizing flow) or an Oryx probabilistic program that we can use `log_prob` with.
> You would have a better understanding than me about this -- do you feel like the current state of Oryx is "ruleset complete" for ILDJ/inverses?
The rules are probably not as complete as I'd like them to be, mainly for lack of time/demand. I'm happy to accept new rules in a PR though!
Control flow is a big hole in the rules right now -- inverting something like `scan` is possible, but highly nontrivial. However, doing so would enable things like HMMs expressed via `scan`! That was a long-term goal w/ Oryx for me but I have never gotten around to implementing it.
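To make concrete what inverting a `scan` involves, here is a hand-written sketch for the simplest chain-structured model, a Gaussian random walk. This is an illustration in plain JAX under stated assumptions, not an Oryx rule: the names `random_walk`, `random_walk_inverse`, and `random_walk_logpdf` are hypothetical, and the inverse is written manually, whereas a general rule would have to derive it from the scan body.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def random_walk(noise, x0=0.0):
    # Forward program: x_t = x_{t-1} + noise_t, expressed with scan.
    def step(x_prev, eps):
        x = x_prev + eps
        return x, x
    _, xs = jax.lax.scan(step, jnp.asarray(x0, dtype=noise.dtype), noise)
    return xs


def random_walk_inverse(xs, x0=0.0):
    # Hand-written inverse: recover noise_t = x_t - x_{t-1}, also a scan.
    def step(x_prev, x):
        return x, x - x_prev
    _, noise = jax.lax.scan(step, jnp.asarray(x0, dtype=xs.dtype), xs)
    return noise


def random_walk_logpdf(xs, x0=0.0):
    # The Jacobian of noise -> xs is unit lower-triangular, so the ILDJ is 0
    # and the density is the standard normal density of the recovered noise.
    return jnp.sum(norm.logpdf(random_walk_inverse(xs, x0)))
```

The inverse here is itself a scan over the outputs. The nontrivial part for a general rule is working out, from an arbitrary scan body, which carried values can be recovered from the outputs at each step.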
> Sounds awesome! Do you have a repo/doc I could read to learn more?
@sharadmv I've sent you a private email about this (private, because we're still working closed source).
> The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx `nn` to implement RealNVP and MAF). However, the `nn` library is a bit more opinionated than it needs to be. I'd recommend using `harvest` directly to build your own mini state-management library.
`harvest` is super neat. I wrote a restricted version of `harvest` previously -- I wasn't concerned with handling higher-order primitives (partially, in my modeling code, I'm still not, because there's a level of model design modularity which allows me to use higher-order models to support things like `vmap` or `scan`, etc).
> Control flow is a big hole in the rules right now -- inverting something like `scan` is possible, but highly nontrivial. However, doing so would enable things like HMMs expressed via `scan`! That was a long-term goal w/ Oryx for me but I have never gotten around to implementing it.
Right, this is pretty interesting. Because Gen doesn't assume any restrictions on the return value function $f$, there's a straightforward way to support things like `logpdf` for internal random choices in models which use `scan` (Gen doesn't assume that the return value function is a transformation whose output you wish to constrain). The way you gain access to `scan` is to use one of these higher-order models above (which also implement Gen's interface).
I am curious what happens if I use an Oryx model without control flow, which supports `logpdf`, and then shove it into one of the higher-order models above.
(re -- when I make comments about Gen + Oryx, I'm thinking of Oryx as providing a DSL for defining objects with `sample` and exact `logpdf` evaluation - but if an object supports these two interfaces, you can automatically define Gen's interface on it.)
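A rough sketch of that last point: given only a sampler and an exact log density, two interface-style methods fall out directly. The class `ExactDensity` and its signatures are hypothetical. Only the method names `simulate` and `assess` are borrowed from Gen's interface; traces, choice maps, and the other interface methods are omitted, so this is not Gen's actual API.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax
from jax.scipy.stats import norm


@dataclass(frozen=True)
class ExactDensity:
    # Hypothetical wrapper around any object with `sample` and exact `logpdf`.
    sample: Callable[..., Any]  # (key, *args) -> value
    logpdf: Callable[..., Any]  # (value, *args) -> log density

    def simulate(self, key, *args):
        # Draw a value and score it under the exact density.
        value = self.sample(key, *args)
        return value, self.logpdf(value, *args)

    def assess(self, value, *args):
        # Score a fully constrained value.
        return self.logpdf(value, *args)


# Example instance: a normal distribution with location and scale arguments.
normal = ExactDensity(
    sample=lambda key, mu, sigma: mu + sigma * jax.random.normal(key),
    logpdf=lambda v, mu, sigma: norm.logpdf(v, loc=mu, scale=sigma),
)
value, score = normal.simulate(jax.random.PRNGKey(0), 1.0, 2.0)
```

An object built from `inverse`/`ildj` transforms could in principle be slotted into the `logpdf` field in the same way, which is the sense in which it would "act like a distribution".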
Hi all!
Will `Oryx` continue to be actively maintained? Are there maintainers who are hoping to continue working on the package?