patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

merging multiple eqx.Module classes #713

Open ToshiyukiBandai opened 6 months ago

ToshiyukiBandai commented 6 months ago

Hi all,

I am building a multi-physics solver using JAX and equinox, and I want to merge multiple dataclasses that inherit from eqx.Module. With standard Python classes, the following merge function works:

```python
import jax.numpy as jnp


def merge(ob1, ob2):
    # Copy every instance attribute of ob2 into ob1 (mutates ob1 in place).
    ob1.__dict__.update(ob2.__dict__)
    return ob1


class State1:
    def __init__(self, x):
        self.x = x


class State2:
    def __init__(self, y):
        self.y = y


state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)
print(state3.x)  # x from state1
print(state3.y)  # y from state2
```

If I define the state classes with eqx.Module instead, the merge function also seems to work:

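```python
import equinox as eqx
import jax
import jax.numpy as jnp

Float_1D = jax.Array  # stand-in: the post does not show how Float_1D is defined


class State1(eqx.Module):
    x: Float_1D


class State2(eqx.Module):
    y: Float_1D


state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)  # merge as defined above
print(state3.x)  # x from state1
print(state3.y)  # y from state2
```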

Is this a good way to merge two dataclasses? Let me know if there are better ways to do it.

lockwo commented 6 months ago

I doubt this will work consistently. Here is a simple example where applying a JAX transformation (jax.grad) to the merged object fails:

```python
import equinox as eqx
from jax import numpy as jnp
import jax

class State1(eqx.Module):
    x: jnp.ndarray

class State2(eqx.Module):
    y: jnp.ndarray

def merge(ob1, ob2):
    ob1.__dict__.update(ob2.__dict__)
    return ob1

state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)
print(state3.x) # x from state1
print(state3.y) # y from state2

def fn(state):
  return jnp.squeeze(state.y + state.x)

jax.grad(fn)(state3)
```

```text
[2.]
[3.]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-04689d233c68> in <cell line: 24>()
     22   return jnp.squeeze(state.y + state.x)
     23 
---> 24 jax.grad(fn)(state3)

    [... skipping hidden 10 frame]

<ipython-input-7-04689d233c68> in fn(state)
     20 
     21 def fn(state):
---> 22   return jnp.squeeze(state.y + state.x)
     23 
     24 jax.grad(fn)(state3)

AttributeError: 'State1' object has no attribute 'y'
```

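If your goal is just printing or something similar, it may be fine, but if you want to do JAX computations with that object, I would use a different approach. JAX flattens an eqx.Module using only the fields declared on its class, so y, which was only injected into the instance's __dict__, is dropped when jax.grad flattens and rebuilds state3; that is why the rebuilt State1 has no attribute y. If you know a priori what State3 should look like, I would just define that class and initialize it from a State1 instance and a State2 instance.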

ToshiyukiBandai commented 6 months ago

That makes sense. I will initialize state3 using state1 and state2! Thank you!