Open Sudhanshu-G-Ogalapur opened 2 hours ago
print("Hello World")
Hi @lockwo I have pasted the training, validation and testing files below. and entire code in one block Testing1.csv Training1.csv Validation1.csv I am coding in Jupyter notebook. Please check. Thank you.
import jax
import jax.numpy as jnp
import optax
from diffrax import Kvaerno5, diffeqsolve, ODETerm, SaveAt, RecursiveCheckpointAdjoint, PIDController
import haiku as hk
import pandas as pd
import matplotlib.pyplot as plt
train_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Training data\Training1.csv")
val_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Validation data\Validation1.csv")
test_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Testing data\Testing1.csv")
train_y0 = jnp.array(train_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values) # species concentration
train_y_true = jnp.array(train_data[['Temp rate','H2 RR','O2 RR','H2O RR','OH RR','O RR']].values) # Reaction rates
val_y0 = jnp.array(val_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# species concentration
val_y_tue = jnp.array(val_data[['Temp rate','H2 RR','O2 RR','H2O RR','OH RR','O RR']].values)# Reaction rates
test_y0 = jnp.array(test_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# species concentration
test_y_true = jnp.array(test_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# Reaction rates
class NeuralODE(hk.Module):
def __init__(self, name='none'):
super().__init__(name=name)
def __call__(self, x):
x = hk.Linear(256)(x)
x = jax.nn.relu(x)
x = hk.Linear(128)(x)
x = jax.nn.relu(x)
x = hk.Linear(64)(x)
x = jax.nn.relu(x)
x = hk.Linear(16)(x)
x = jax.nn.relu(x)
x = hk.Linear(6)(x)
return jax.nn.tanh(x)
def forward_fn(x):
network = NeuralODE()
return network(x)
# Transfering the netwoek to be usable in jax
network = hk.transform(forward_fn)
def ode_rhs(t, y, params):
return network.apply(params, y)
t_span = (0, 0.4e-3)
# Define the solve_ode function
def solve_ode(y0, params, t0=0, t1=0.4e-3):
# Define the ODE term based on ode_rhs
term = ODETerm(lambda t, y, args: ode_rhs(t, y, params))
solver = Kvaerno5()
stepsize_controller = PIDController(rtol=1e-4, atol=1e-7)
solution = diffeqsolve(
term,
solver,
t0=t0,
t1=t1,
dt0=1e-6,
y0=y0,
saveat=SaveAt(ts=jnp.linspace(t0, t1, 100)),
adjoint=RecursiveCheckpointAdjoint(),
stepsize_controller=stepsize_controller
)
return solution.ys
def loss(params, y0, t_span, y_true):
y_pred = solve_ode(y0, params, t_span[0], t_span[1])
return jnp.mean((y_pred - y_true) ** 2)
def init_network():
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones(6) # 6 (species mass fractions + temperature)
params = network.init(rng, dummy_input) # Initialize parameters
return params
optimizer = optax.adam(learning_rate = 0.0003)
params = init_network()
opt_state = optimizer.init(params)
epochs = 500
training_loss = []
val_loss = []
for epoch in range(epochs):
epoch_train_loss = 0
for y0, y_true in zip(train_y0, train_y_true):
gards = jax.grad(loss)(params, y0, t_span, y_true)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
epoch_train_loss += loss(params, y0, t_span, y_true)
training_loss.append(epoch_train_loss/len(train_y0))
epoch_val_loss = jnp.mean([loss(params, y0, t_span, y_true) for y0, y_true in zip(val_y0, val_y_true)])
val_loss.append(epoch_val_loss/len(val_y0))
print(f"Epoch {epoch + 1}, Training_loss: {training_loss[-1]}, Validation_loss: {val_loss[-1]}")
I am building a NeuralODE framework using JAX to predict Hydrogen Oxidation. The model predicts the temperature rate and reaction rates for species mass fraction and the ode solver solves it to get species mass fraction and temperature. When i run the code, I am getting this value error. I have pasted the code below. ValueError:
terms
must be a PyTree ofAbstractTerms
(such asODETerm
), with structure <class 'diffrax._term.AbstractTerm'> I previously saw @lockwo solving this issue. If anyone of you or him could look at this, it would be of great help. Thank you!Error
ValueError Traceback (most recent call last) Cell In[10], line 7 5 epoch_train_loss = 0 6 for y0, y_true in zip(train_y0, train_y_true): ----> 7 gards = jax.grad(loss)(params, y0, t_span, y_true) 8 updates, opt_state = optimizer.update(grads, opt_state) 9 params = optax.apply_updates(params, updates)
Cell In[7], line 2, in loss(params, y0, t_span, y_true) 1 def loss(params, y0, t_span, y_true): ----> 2 y_pred = solve_ode(y0, params, t_span[0], t_span[1]) 3 return jnp.mean((y_pred - y_true) ** 2)
Cell In[6], line 12, in solve_ode(y0, params, t0, t1) 9 stepsize_controller = PIDController(rtol=1e-4, atol=1e-7) 11 # Solve the ODE using diffeqsolve ---> 12 solution = diffeqsolve( 13 term, 14 solver, 15 t0=t0, 16 t1=t1, 17 dt0=1e-6, 18 y0=y0, 19 saveat=SaveAt(ts=jnp.linspace(t0, t1, 100)), 20 adjoint=RecursiveCheckpointAdjoint(), 21 stepsize_controller=stepsize_controller 22 ) 23 return solution.ys
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax_integrate.py:1028, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event) 1024 # Error checking 1025 if not _term_compatible( 1026 y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs 1027 ): -> 1028 raise ValueError( 1029 "
terms
must be a PyTree ofAbstractTerms
(such asODETerm
), with " 1030 f"structure {solver.term_structure}" 1031 ) 1033 if is_sde(terms): 1034 if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):ValueError:
terms
must be a PyTree ofAbstractTerms
(such asODETerm
), with structure <class 'diffrax._term.AbstractTerm'>Code:
import jax import jax.numpy as jnp import optax from diffrax import Kvaerno5, diffeqsolve, ODETerm, SaveAt, RecursiveCheckpointAdjoint, PIDController import haiku as hk import pandas as pd import matplotlib.pyplot as plt
train_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Training data\Training1.csv") val_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Validation data\Validation1.csv") test_data = pd.read_csv("C:\Sudhanshu ME23M209\Data Generation and Training\Testing data\Testing1.csv")
train_y0 = jnp.array(train_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values) # species concentration train_y_true = jnp.array(train_data[['Temp rate','H2 RR','O2 RR','H2O RR','OH RR','O RR']].values) # Reaction rates val_y0 = jnp.array(val_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# species concentration val_y_tue = jnp.array(val_data[['Temp rate','H2 RR','O2 RR','H2O RR','OH RR','O RR']].values)# Reaction rates test_y0 = jnp.array(test_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# species concentration test_y_true = jnp.array(test_data[['Temperature','Y H2','Y O2','Y H2O','Y OH','Y O']].values)# Reaction rates
class NeuralODE(hk.Module):
def forward_fn(x):
network = hk.transform(forward_fn) def ode_rhs(t, y, params):
t_span = (0, 0.4e-3)
def solve_ode(y0, params, t0=0, t1=0.4e-3):
def loss(params, y0, t_span, y_true):
def init_network():
optimizer = optax.adam(learning_rate = 0.0003) params = init_network() opt_state = optimizer.init(params)
epochs = 500 training_loss = [] val_loss = []
for epoch in range(epochs):
Training1.csv Validation1.csv Testing1.csv