patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Using a Pytree state as a state #460

Closed ASKabalan closed 4 months ago

ASKabalan commented 4 months ago

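I am trying to use a pytree (a registered dataclass) as the ODE state in diffeqsolve.

However, solution.ys comes back as a single state object rather than a list of states, one per saved time.

MWE: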


import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController
# Define and register the MyStruct class as a pytree:
# x and y are array leaves, op is static metadata.
@partial(jax.tree_util.register_dataclass, data_fields=['x', 'y'], meta_fields=['op'])
@dataclass
class MyStruct:
    x: jax.Array
    y: jax.Array
    op: str

# Define the ODE function for MyStruct
def ode_fn_struct(t, state, args):
    x = state.x
    y = state.y
    dx_dt = -x + y
    dy_dt = x - y
    dstate_dt = MyStruct(x=dx_dt, y=dy_dt, op=state.op)
    return dstate_dt

# Initialize the state as an instance of MyStruct
initial_state_struct = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')

# Define the ODE function for jnp.stack
def ode_fn_stack(t, state, args):
    x = state[:, 0]
    y = state[:, 1]
    dx_dt = -x + y
    dy_dt = x - y
    dstate_dt = jnp.stack([dx_dt, dy_dt], axis=-1)
    return dstate_dt

# Initialize the state as a stacked array
x_init = jnp.ones(3)
y_init = jnp.arange(3)
initial_state_stack = jnp.stack([x_init, y_init], axis=-1)

# Common parameters for both solutions
t0 = 0.0
t1 = 10.0
saveat = SaveAt(ts=[0.0, 1.0, 2.0, 5.0, 10.0])
solver = Tsit5()
step_size = PIDController(rtol=1e-5, atol=1e-6)

# Solve the ODE using MyStruct
ode_term_struct = ODETerm(ode_fn_struct)
solution_struct = diffeqsolve(
    ode_term_struct, solver, t0=t0, t1=t1, dt0=0.1, stepsize_controller=step_size,
    y0=initial_state_struct, saveat=saveat
)

# Solve the ODE using jnp.stack
ode_term_stack = ODETerm(ode_fn_stack)
solution_stack = diffeqsolve(
    ode_term_stack, solver, t0=t0, t1=t1, dt0=0.1, stepsize_controller=step_size,
    y0=initial_state_stack, saveat=saveat
)

# Print the solutions
print("Using MyStruct:")
print("Time points:", solution_struct.ts)
print("Time solution_struct X:", solution_struct.ys.x)
print("Time solution_struct Y:", solution_struct.ys.y)

# The stacked-array solution can be iterated along its leading (time) axis
print("\nUsing jnp.stack:")
print("Time points:", solution_stack.ts)
for i, state in enumerate(solution_stack.ys):
    print(f"State at t={solution_stack.ts[i]}: x={state[:, 0]}, y={state[:, 1]}")

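# Attempt the same per-time iteration on the struct solution.
# The loop below raises a TypeError: solution_struct.ys is a single MyStruct
# whose fields are stacked along the time axis, not an iterable of states.
print(f"Time points: {solution_struct.ts}")
print(f"Shape of solution_struct.ys.x: {solution_struct.ys.x.shape}")
for i, state in enumerate(solution_struct.ys):
    print(f"State at t={solution_struct.ts[i]}: x={state.x}, y={state.y}, op={state.op}")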

What would be the best-practice way to unpack the result into a list of states, one per saved time?

lockwo commented 4 months ago

As a general statement, Equinox/JAX prefers to operate on pytrees of arrays rather than arrays or lists of pytrees (with some exceptions). Many functions therefore return a single pytree whose member variables are each "stacked" along a leading axis, rather than a "stack" of pytrees. Since your pytree now contains stacked member variables, you can extract values from it as you normally would from any pytree. So if you want a list of individual states, you could just do

l = [jax.tree_util.tree_map(lambda x: x[i], solution_struct.ys) for i in range(len(solution_struct.ts))]

or something to that effect.
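
For reference, here is a minimal self-contained sketch of that idea which runs without diffrax. The `unstack` helper and the hand-built `stacked` value are illustrative assumptions, not part of diffrax or JAX. `stacked` only stands in for `solution_struct.ys`, assuming every array leaf shares the same leading (time) axis.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.tree_util.register_dataclass, data_fields=["x", "y"], meta_fields=["op"])
@dataclass
class MyStruct:
    x: jax.Array
    y: jax.Array
    op: str


def unstack(tree):
    """Hypothetical helper: split a pytree with stacked leaves into a list of pytrees.

    Assumes every leaf has the same leading-axis length.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    n = leaves[0].shape[0]
    # Index each leaf at position i; metadata such as `op` is carried over unchanged.
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]


# Stand-in for solution_struct.ys: 5 saved times, 3 components per field.
stacked = MyStruct(
    x=jnp.arange(15.0).reshape(5, 3),
    y=jnp.zeros((5, 3)),
    op="add",
)

states = unstack(stacked)
for i, state in enumerate(states):
    print(f"state {i}: x={state.x}, y={state.y}, op={state.op}")
```

Each element of `states` is an ordinary `MyStruct` with fields of shape `(3,)`, so the per-state loop from the MWE works on it. Building a Python list like this is convenient for printing or debugging. For further numerical work it is usually better to keep the stacked pytree and index or `jax.vmap` over its leading axis.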

ASKabalan commented 4 months ago

Thank you, this works.