nstarman closed this issue 6 months ago.
Hmm, that seems a bit surprising. Can you create an MWE? When crossing a `quaxify` boundary, all `Value`s, even those inside PyTrees, should be wrapped, so something like this should work:
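```python
import equinox as eqx
import quax

class Bar(eqx.Module):
    value: quax.Value

class Foo(eqx.Module):
    some_param: Bar

    @quax.quaxify
    def _potential_energy(self, *args):  # remaining arguments elided
        # Inside the quaxify boundary, this nested Value is wrapped.
        return self.some_param.value
```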
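The difference between wrapping only a top-level argument and wrapping every `Value` found anywhere in a PyTree can be illustrated with a minimal, stdlib-only sketch. `Value`, `Wrapped`, `wrap_shallow` and `wrap_tree` are hypothetical stand-ins, not quax's API. Frozen dataclasses play the role of `eqx.Module` PyTrees, and quax's real mechanism (JAX tracing) is not reproduced here. `Bar` and `Foo` mirror the classes above.

```python
from dataclasses import dataclass, fields, is_dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Value:
    """Hypothetical stand-in for a quax Value (an array-like leaf)."""
    data: float


@dataclass(frozen=True)
class Wrapped:
    """Hypothetical stand-in for the tracer-like wrapper applied at the boundary."""
    inner: Value


def wrap_shallow(x: Any) -> Any:
    # Only wraps x if x itself is a Value; Values nested inside x are untouched.
    return Wrapped(x) if isinstance(x, Value) else x


def wrap_tree(x: Any) -> Any:
    # Recursively walks dataclass fields and plain containers, wrapping every Value leaf.
    if isinstance(x, Wrapped):
        return x  # already wrapped; keeps the walk idempotent
    if isinstance(x, Value):
        return Wrapped(x)
    if is_dataclass(x) and not isinstance(x, type):
        return replace(x, **{f.name: wrap_tree(getattr(x, f.name)) for f in fields(x)})
    if isinstance(x, list):
        return [wrap_tree(v) for v in x]
    if isinstance(x, tuple):
        return tuple(wrap_tree(v) for v in x)
    if isinstance(x, dict):
        return {k: wrap_tree(v) for k, v in x.items()}
    return x  # other leaves (floats, strings, ...) pass through unchanged


@dataclass(frozen=True)
class Bar:
    value: Value


@dataclass(frozen=True)
class Foo:
    some_param: Bar


foo = Foo(Bar(Value(1.0)))
print(type(wrap_shallow(foo).some_param.value).__name__)  # Value
print(type(wrap_tree(foo).some_param.value).__name__)  # Wrapped
```

With the shallow version, `some_param.value` stays a plain `Value` because `foo` itself is not a `Value`. The recursive walk reaches it through the nested fields, which is the behaviour the reply above describes for `quaxify`.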
I've solved the issue: I was constructing an object incorrectly, so it wasn't being wrapped properly. Constructing it more carefully fixed the problem!
In https://github.com/patrick-kidger/quax/issues/10#issuecomment-1949501175 we discussed a function, `wrap_arrayish_value_into_tracer_like`, to handle cases like this one.

In https://github.com/GalacticDynamics/galax/pull/187 I'm running into the lack of this functionality in the `_potential_energy` methods, where the Parameters of the potential do not pass through the same quax boundary. In our case, `some_param` is often a callable object with a `.value` attribute. In testing this, `self.some_param.value` is not wrapped into a tracer when `self` is wrapped. Would wrapping recursively through the PyTree fix this problem?