patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.44k stars 129 forks source link

NeuralODE implementation using JAX #525

Open Sudhanshu-G-Ogalapur opened 2 hours ago

Sudhanshu-G-Ogalapur commented 2 hours ago

I am building a NeuralODE framework using JAX to predict hydrogen oxidation. The model predicts the temperature rate and the reaction rates for the species mass fractions, and the ODE solver integrates these to obtain the species mass fractions and temperature. When I run the code, I get the value error below; I have pasted the code underneath it. ValueError: terms must be a PyTree of AbstractTerms (such as ODETerm), with structure <class 'diffrax._term.AbstractTerm'> I previously saw @lockwo solving this issue. If any of you (or he) could look at this, it would be of great help. Thank you!

Error

ValueError Traceback (most recent call last) Cell In[10], line 7 5 epoch_train_loss = 0 6 for y0, y_true in zip(train_y0, train_y_true): ----> 7 gards = jax.grad(loss)(params, y0, t_span, y_true) 8 updates, opt_state = optimizer.update(grads, opt_state) 9 params = optax.apply_updates(params, updates)

[... skipping hidden 10 frame]

Cell In[7], line 2, in loss(params, y0, t_span, y_true) 1 def loss(params, y0, t_span, y_true): ----> 2 y_pred = solve_ode(y0, params, t_span[0], t_span[1]) 3 return jnp.mean((y_pred - y_true) ** 2)

Cell In[6], line 12, in solve_ode(y0, params, t0, t1) 9 stepsize_controller = PIDController(rtol=1e-4, atol=1e-7) 11 # Solve the ODE using diffeqsolve ---> 12 solution = diffeqsolve( 13 term, 14 solver, 15 t0=t0, 16 t1=t1, 17 dt0=1e-6, 18 y0=y0, 19 saveat=SaveAt(ts=jnp.linspace(t0, t1, 100)), 20 adjoint=RecursiveCheckpointAdjoint(), 21 stepsize_controller=stepsize_controller 22 ) 23 return solution.ys

[... skipping hidden 15 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax_integrate.py:1028, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event) 1024 # Error checking 1025 if not _term_compatible( 1026 y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs 1027 ): -> 1028 raise ValueError( 1029 "terms must be a PyTree of AbstractTerms (such as ODETerm), with " 1030 f"structure {solver.term_structure}" 1031 ) 1033 if is_sde(terms): 1034 if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):

ValueError: terms must be a PyTree of AbstractTerms (such as ODETerm), with structure <class 'diffrax._term.AbstractTerm'>

Code:

import jax import jax.numpy as jnp import optax from diffrax import Kvaerno5, diffeqsolve, ODETerm, SaveAt, RecursiveCheckpointAdjoint, PIDController import haiku as hk import pandas as pd import matplotlib.pyplot as plt

train_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Training data\Training1.csv") val_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Validation data\Validation1.csv") test_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Testing data\Testing1.csv")

train_y0 = jnp.array(train_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values) # species concentration train_y_true = jnp.array(train_data[['Temp rate','H2 RR','O2 RR','H2O RR','OH RR','O RR']].values) # Reaction rates val_y0 = jnp.array(val_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# species concentration val_y_tue = jnp.array(val_data[['Temp rate','H2 RR','O2 RR','H2O RR','OH RR','O RR']].values)# Reaction rates test_y0 = jnp.array(test_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# species concentration test_y_true = jnp.array(test_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# Reaction rates

class NeuralODE(hk.Module):

def __init__(self, name='none'):
    super().__init__(name=name)

def __call__(self, x):
    x = hk.Linear(256)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(128)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(64)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(16)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(6)(x)
    return jax.nn.tanh(x)

def forward_fn(x):

network = NeuralODE()
return network(x)

network = hk.transform(forward_fn) def ode_rhs(t, y, params):

return network.apply(params, y)

t_span = (0, 0.4e-3)

def solve_ode(y0, params, t0=0, t1=0.4e-3):

term = ODETerm(lambda t, y, args: ode_rhs(t, y, params))
solver = Kvaerno5()
stepsize_controller = PIDController(rtol=1e-4, atol=1e-7)
solution = diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=1e-6,
    y0=y0,
    saveat=SaveAt(ts=jnp.linspace(t0, t1, 100)),
    adjoint=RecursiveCheckpointAdjoint(),
    stepsize_controller=stepsize_controller
)
return solution.ys

def loss(params, y0, t_span, y_true):

y_pred = solve_ode(y0, params, t_span[0], t_span[1])
return jnp.mean((y_pred - y_true) ** 2)

def init_network():

rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones(6)  # 6 (species mass fractions + temperature)
params = network.init(rng, dummy_input)  # Initialize parameters
return params

optimizer = optax.adam(learning_rate = 0.0003) params = init_network() opt_state = optimizer.init(params)

epochs = 500 training_loss = [] val_loss = []

for epoch in range(epochs):

epoch_train_loss = 0
for y0, y_true in zip(train_y0, train_y_true):
    gards = jax.grad(loss)(params, y0, t_span, y_true)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    epoch_train_loss += loss(params, y0, t_span, y_true)

training_loss.append(epoch_train_loss/len(train_y0))
epoch_val_loss = jnp.mean([loss(params, y0, t_span, y_true) for y0, y_true in zip(val_y0, val_y_true)])
val_loss.append(epoch_val_loss/len(val_y0))
print(f"Epoch {epoch + 1}, Training_loss: {training_loss[-1]}, Validation_loss: {val_loss[-1]}")

Training1.csv Validation1.csv Testing1.csv

lockwo commented 2 hours ago
  1. Format the code as one block by wrapping it in "```" — see the raw of:
print("Hello World")
  2. Provide an MVC (https://stackoverflow.com/help/minimal-reproducible-example); I don't have those files and thus can't run the code.
Sudhanshu-G-Ogalapur commented 2 hours ago

Hi @lockwo, I have pasted the training, validation and testing files below, and the entire code in one block: Testing1.csv Training1.csv Validation1.csv. I am coding in a Jupyter notebook. Please check. Thank you.

Code

import jax
import jax.numpy as jnp
import optax
from diffrax import Kvaerno5, diffeqsolve, ODETerm, SaveAt, RecursiveCheckpointAdjoint, PIDController
import haiku as hk
import pandas as pd
import matplotlib.pyplot as plt

# Load the train/validation/test splits.
# Raw strings (r"...") are required: the original plain strings contained
# sequences like "\T" and "\V", which trigger invalid-escape warnings and
# will become errors in future Python versions.
train_data = pd.read_csv(r"C:\Sudhanshu ME23M209\Data Generation and Training\Training data\Training1.csv")
val_data = pd.read_csv(r"C:\Sudhanshu ME23M209\Data Generation and Training\Validation data\Validation1.csv")
test_data = pd.read_csv(r"C:\Sudhanshu ME23M209\Data Generation and Training\Testing data\Testing1.csv")

# State columns (y0): temperature + species mass fractions.
# Target columns (y_true): temperature rate + species reaction rates.
_STATE_COLS = ['Temperature', 'Y H2', 'Y O2', 'Y H2O', 'Y OH', 'Y O']
_RATE_COLS = ['Temp rate', 'H2 RR', 'O2 RR', 'H2O RR', 'OH RR', 'O RR']

train_y0 = jnp.array(train_data[_STATE_COLS].values)    # initial states
train_y_true = jnp.array(train_data[_RATE_COLS].values)  # target rates
# BUGFIX: was `val_y_tue`, leaving `val_y_true` (used by the training loop)
# undefined and raising a NameError at validation time.
val_y0 = jnp.array(val_data[_STATE_COLS].values)
val_y_true = jnp.array(val_data[_RATE_COLS].values)
test_y0 = jnp.array(test_data[_STATE_COLS].values)
# BUGFIX: the test targets previously re-selected the state columns
# (Temperature, Y H2, ...) instead of the rate columns.
test_y_true = jnp.array(test_data[_RATE_COLS].values)

class NeuralODE(hk.Module):
    """MLP vector field: maps the 6-dim state (temperature + 5 species
    mass fractions) to a 6-dim output squashed into (-1, 1) by tanh."""

    def __init__(self, name='none'):
        super().__init__(name=name)

    def __call__(self, x):
        # Hidden stack 256 -> 128 -> 64 -> 16, ReLU between layers.
        # Haiku auto-numbers the Linear modules identically whether they
        # are created in a loop or written out one by one.
        for width in (256, 128, 64, 16):
            x = jax.nn.relu(hk.Linear(width)(x))
        # Final projection back to the 6 state-rate channels.
        return jax.nn.tanh(hk.Linear(6)(x))

def forward_fn(x):
    """Instantiate the NeuralODE module and apply it to x.

    This is the function handed to hk.transform below; haiku requires
    module construction to happen inside the transformed function.
    """
    return NeuralODE()(x)

# Transform the network into a pure (init, apply) pair usable with JAX.
network = hk.transform(forward_fn)
def ode_rhs(t, y, params):
    """Right-hand side dy/dt = network(y) for the neural ODE.

    `t` is accepted to match diffrax's (t, y, args) vector-field signature
    but is unused (the system is autonomous); `params` are the Haiku
    network parameters. Note `network.apply(params, y)` omits an rng key,
    which works because the forward pass is deterministic — consider
    hk.without_apply_rng for clarity.
    """
    return network.apply(params, y)

# Integration window (t0, t1) = (0, 0.4e-3) — presumably seconds; confirm
# against the units of the rate columns in the training data.
t_span = (0, 0.4e-3)
# Define the solve_ode function 
def solve_ode(y0, params, t0=0, t1=0.4e-3):
    """Integrate the neural ODE from t0 to t1 and return the trajectory
    sampled at 100 evenly spaced times (shape: 100 x state-dim)."""
    # Bind the current parameters into the vector field via a closure so
    # the (t, y, args) signature expected by ODETerm is satisfied.
    rhs = lambda t, y, args: ode_rhs(t, y, params)
    solution = diffeqsolve(
        ODETerm(rhs),
        Kvaerno5(),  # implicit solver, suited to stiff chemistry
        t0=t0,
        t1=t1,
        dt0=1e-6,
        y0=y0,
        saveat=SaveAt(ts=jnp.linspace(t0, t1, 100)),
        adjoint=RecursiveCheckpointAdjoint(),
        stepsize_controller=PIDController(rtol=1e-4, atol=1e-7),
    )
    return solution.ys

def loss(params, y0, t_span, y_true):
    """Mean-squared error between the solved trajectory and the target."""
    trajectory = solve_ode(y0, params, t_span[0], t_span[1])
    residual = trajectory - y_true
    return jnp.mean(jnp.square(residual))

def init_network():
    """Initialise the Haiku network parameters with a fixed seed.

    A dummy 6-vector (temperature + 5 species mass fractions) fixes the
    input shape for shape inference.
    """
    key = jax.random.PRNGKey(0)
    sample_input = jnp.ones(6)
    return network.init(key, sample_input)

# Optimiser and its state, initialised against the freshly created params.
optimizer = optax.adam(learning_rate = 0.0003)
params = init_network()
opt_state = optimizer.init(params)

# Training configuration and per-epoch metric histories.
epochs = 500
training_loss = []
val_loss = []

# Training loop: one gradient step per (y0, y_true) sample, then a
# full-set validation pass at the end of each epoch.
for epoch in range(epochs):
    epoch_train_loss = 0.0

    for y0, y_true in zip(train_y0, train_y_true):
        # BUGFIX: the gradient was assigned to a misspelled `gards` while
        # `optimizer.update` read the undefined `grads` (NameError).
        # value_and_grad also returns the loss from the same ODE solve,
        # removing the second solve previously done only for logging
        # (note: this logs the pre-update loss for the sample).
        sample_loss, grads = jax.value_and_grad(loss)(params, y0, t_span, y_true)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        epoch_train_loss += sample_loss

    training_loss.append(epoch_train_loss / len(train_y0))
    # BUGFIX: the validation loss was averaged twice (jnp.mean(...) and
    # then divided by len(val_y0) again). Stack the per-sample scalars
    # before jnp.mean rather than passing a Python list.
    # (Relies on `val_y_true` being defined by the data-loading step.)
    epoch_val_loss = jnp.mean(
        jnp.stack([loss(params, y0, t_span, y_true)
                   for y0, y_true in zip(val_y0, val_y_true)])
    )
    val_loss.append(epoch_val_loss)
    print(f"Epoch {epoch + 1}, Training_loss: {training_loss[-1]}, Validation_loss: {val_loss[-1]}")