patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

best way to write a wrapper function for tree_at #784

Open: sede-fa opened this issue 1 month ago

sede-fa commented 1 month ago

Hi,

I'm trying to rewrite my process flowsheet simulator in equinox. I'm able to solve individual components fine by creating my classes as follows:

class Unit1(eqx.Module):
    F0: interfaces.Flange
    F1: interfaces.Flange
    # ... other params

    def __init__(self, F0):  # ... plus other params
        self.F0 = F0
        self.F1 = F0  # just for initialising

    def forward(self):
        Eq = jnp.array([
            self.F0.n - self.F1.n,
            # ... other balance equations here
        ])
        return Eq

    @eqx.filter_jit
    def f_sol(self, x):
        self = eqx.tree_at(lambda m: m, self, self.update(x))
        return self.forward()

    def update(self, x):
        # tree_at needs the module itself as its second argument
        self = eqx.tree_at(lambda m: m.F1.n, self, x[0])
        # ... update remaining fields of Flange F1
        return self  # return the new module so callers can use it

    def run_sol(self):
        # ... run solver here with initial guesses etc.
        self = eqx.tree_at(lambda m: m, self, self.update(solver.x))
        return self

However, when I want to connect components and update them sequentially, the only way I've found is repeated use of eqx.tree_at, as follows:

class SystemModel(eqx.Module):
    unit1: Unit1
    unit2: Unit2

    def __init__(self):  # ... plus other params
        F0 = interfaces.Flange(n=1)  # ... other Flange fields
        self.unit1 = Unit1(F0, ...)
        self.unit2 = Unit2(self.unit1.F1, ...)

    def forward(self, F0):
        self = eqx.tree_at(lambda m: m.unit1, self, F0)  # make connection to source
        self = eqx.tree_at(lambda m: m.unit1, self, self.unit1.run_sol())  # run solver to update outlet flange in unit1
        self = eqx.tree_at(lambda m: m.unit2, self, self.unit1.F0)  # sequentially update inlet of unit2
        self = eqx.tree_at(lambda m: m.unit2, self, self.unit2.run_sol())  # run solver to update unit2

        return self

and this can be executed as follows:

sysModel = SystemModel(...)
F0 = interfaces.Flange(n=1)  # ... other Flange fields

sysModel = sysModel.forward(F0)

This all executes fine. However, I'm wondering whether it would be possible to do this through a wrapper method like connect(Flange1, Flange2) that handles the connections and also lets me add extra features, such as the mermaid graph generation I have in the numpy version of my framework.

I know this is out of scope for equinox and a long shot, but any help would be much appreciated :)

patrick-kidger commented 1 month ago

Is there any reason you need to modify them after-the-fact? Can this all just happen at initialization?

At any rate the construction you have here looks a little odd. Assuming everything is functionally pure (I sure hope so!), code like

self = eqx.tree_at(lambda m: m.unit1, self, F0)
self = eqx.tree_at(lambda m: m.unit1, self, self.unit1.run_sol())

should be equivalent to

self = eqx.tree_at(lambda m: m.unit1, self, F0.run_sol())