ToshiyukiBandai opened 6 months ago
I doubt this approach will work consistently. Here is a simple example where transforming the merged object with `jax.grad` fails:
```python
import equinox as eqx
from jax import numpy as jnp
import jax

class State1(eqx.Module):
    x: jnp.ndarray

class State2(eqx.Module):
    y: jnp.ndarray

def merge(ob1, ob2):
    # Writing to __dict__ directly bypasses eqx.Module's frozen __setattr__
    ob1.__dict__.update(ob2.__dict__)
    return ob1

state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)
print(state3.x)  # x from state1
print(state3.y)  # y from state2

def fn(state):
    return jnp.squeeze(state.y + state.x)

jax.grad(fn)(state3)
```
```
[2.]
[3.]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-04689d233c68> in <cell line: 24>()
     22     return jnp.squeeze(state.y + state.x)
     23
---> 24 jax.grad(fn)(state3)

    [... skipping hidden 10 frame]

<ipython-input-7-04689d233c68> in fn(state)
     20
     21 def fn(state):
---> 22     return jnp.squeeze(state.y + state.x)
     23
     24 jax.grad(fn)(state3)

AttributeError: 'State1' object has no attribute 'y'
```
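If your goal is just printing or something similar, this may be fine, but if you want to do JAX computations with that object, I would use a different approach. The problem is that JAX flattens and rebuilds an `eqx.Module` using only its declared dataclass fields: `y` was written into `state1.__dict__` but is not a field of `State1`, so it is dropped when `jax.grad` reconstructs the object before calling `fn`. If you know a priori what `State3` should look like, I would just define that class and initialize it from a `state1` and a `state2` instance.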
That makes sense. I will initialize state3 using state1 and state2! Thank you!
Hi all,
I am building a multi-physics solver using JAX and Equinox, and I want to merge multiple dataclasses that inherit from `eqx.Module`. If I use standard Python classes, the following `merge` function works:

If I define the state classes using `eqx.Module`, the `merge` function also seems to work:

Is this way of merging two dataclasses good practice? Let me know if there are better ways to do it.