patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0

quax boundary recursive wrapping of a pytree #13

Closed: nstarman closed this issue 6 months ago.

nstarman commented 6 months ago

In https://github.com/patrick-kidger/quax/issues/10#issuecomment-1949501175 we discussed a possible function, `wrap_arrayish_value_into_tracer_like`, to handle cases such as:

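```python
def funct(x, y):
    # Helper discussed in #10 (not an existing quax function): wrap a Value
    # created inside the function into a tracer like x's, so it can be
    # combined with x and y.
    z = wrap_arrayish_value_into_tracer_like(Value(...), x)
    return x + y + z
```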

In https://github.com/GalacticDynamics/galax/pull/187 I'm running into the lack of this functionality in the `_potential_energy` methods, where the potential's Parameters do not pass through the same quax boundary.

```python
def _potential_energy(self, xyz, t, _G):
    # some_param is a callable Parameter object with a .value attribute.
    m = self.some_param(t)
    return _G * m / ...
```

In our case, `some_param` is often a callable object with a `.value` attribute. In testing, `self.some_param.value` is not wrapped into a tracer when `self` is wrapped. Would recursively wrapping values throughout the PyTree fix this problem?
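
Recursively wrapping every value inside a PyTree amounts to a `jax.tree_util.tree_map` with an `is_leaf` predicate. Below is a minimal sketch in plain JAX, not quax's actual implementation: the hypothetical `Param` class stands in for the `Value` objects on a potential, and the hypothetical `Wrapped` class stands in for the tracer that wraps them.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a Value stored on a potential."""

    value: jax.Array


@dataclasses.dataclass(frozen=True)
class Wrapped:
    """Hypothetical stand-in for the tracer that wraps a Param."""

    inner: Param


def _is_param(node) -> bool:
    return isinstance(node, Param)


def wrap_all(tree):
    """Replace every Param anywhere in `tree` with Wrapped(Param)."""
    # is_leaf makes tree_map stop at Param nodes. This matters when Param is
    # itself a registered PyTree (e.g. an eqx.Module), which tree_map would
    # otherwise flatten down to its arrays.
    return jax.tree_util.tree_map(
        lambda leaf: Wrapped(leaf) if _is_param(leaf) else leaf,
        tree,
        is_leaf=_is_param,
    )


params = {"m": Param(jnp.array(1.0)), "nested": [Param(jnp.array(2.0)), 3.0]}
wrapped = wrap_all(params)
print(type(wrapped["m"]).__name__)          # Wrapped
print(type(wrapped["nested"][0]).__name__)  # Wrapped
print(wrapped["nested"][1])                 # 3.0
```

Non-`Param` leaves such as the plain `3.0` pass through unchanged, so only the values that need wrapping are touched.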

patrick-kidger commented 6 months ago

Hmm, that's a bit surprising. Can you create an MWE? When crossing a `quaxify` boundary, all `Value`s, even those nested inside PyTrees, should be wrapped. So something like this should work:

```python
import equinox as eqx
import quax


class Bar(eqx.Module):
    value: quax.Value


class Foo(eqx.Module):
    some_param: Bar

    @quax.quaxify
    def _potential_energy(self, ...):
        self.some_param.value  # this is wrapped
```

nstarman commented 6 months ago

I've solved it. I was constructing an object incorrectly, so it wasn't being wrapped. Constructing it properly fixed the problem!