I am trying to add both a `flax.linen.Dense` and a `stax.Dense` layer to a model (I have to do this for something else). However, the stax parameters do not show up in the model parameters. I have tried in vain to modify the params, but it keeps giving me the error `ValueError('FrozenDict is immutable.')`. Any help is appreciated. The code is below.
from jax.example_libraries import stax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from optax import adam
from flax.core.frozen_dict import unfreeze
class StaxLayerModule(nn.Module):
    """Flax module that chains a linen ``Dense`` layer with a stax ``Dense`` layer.

    The stax weights are declared through ``self.param`` so that they are
    stored in the module's ``params`` collection next to the linen weights.
    They are therefore returned by ``init`` and can be differentiated and
    updated like any other parameter, without editing the (immutable)
    ``FrozenDict`` by hand.
    """

    def setup(self):
        # Initialize the stax layer and get initial parameters
        init_fn, self.apply_fn = stax.Dense(5)
        self.Dense1 = nn.Dense(5, kernel_init=nn.initializers.normal(.5))
        # stax's init_fn returns (output_shape, params); only the params are
        # kept. The RNG key is supplied by Flax when `init` is called.
        self.stax_params = self.param(
            'stax_params', lambda rng: init_fn(rng, (5,))[1]
        )

    def __call__(self, inputs):
        """Apply the linen ``Dense`` layer, then the stax ``Dense`` layer."""
        x = self.Dense1(inputs)
        x = self.apply_fn(self.stax_params, x)
        return x
# Example usage
model = StaxLayerModule()
input_data = jnp.full((3, 10), 3)  # Example input data
The code is a bit unreadable in its current format, but from the error it seems solvable by passing `mutable=True` in your `apply()` calls. See the documentation of `apply` for more details.
I am trying to add both a `flax.linen.Dense` and a `stax.Dense` layer to a model (I have to do this for something else). However, the stax parameters do not show up in the model parameters. I have tried in vain to modify the params, but it keeps giving me the error `ValueError('FrozenDict is immutable.')`. Any help is appreciated. The code is below.
from jax.example_libraries import stax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from optax import adam
from flax.core.frozen_dict import unfreeze
class StaxLayerModule(nn.Module):
    def setup(self):
Initialize the stax layer and get initial parameters
# Example usage
model = StaxLayerModule()
input_data = jnp.full((3, 10), 3)  # Example input data
print("Input", input_data)
ground_truth = jnp.zeros((3, 5))
optimizer = adam(learning_rate=.01)
rng = jax.random.PRNGKey(0)
params = model.init(rng, input_data)
# NOTE(review): the original passed `modified_params`, which is not defined
# anywhere in the visible code. The optimizer state is built from the
# parameters returned by `model.init` instead -- confirm this is the tree
# that is actually trained.
opt_state = optimizer.init(params)
@jax.jit def mse(params, x, y):
Define the squared loss for a single pair (x,y)
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(grads, params, opt_state):
    """Apply one optimizer step.

    Args:
        grads: Gradients with the same tree structure as ``params``.
        params: Current model parameters.
        opt_state: Current optimizer state.

    Returns:
        A ``(params, opt_state)`` tuple holding the updated parameters and
        the updated optimizer state.
    """
    # `params` is forwarded so that transformations which need the current
    # values (e.g. weight decay) also work; it is optional for plain Adam.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state
@jax.jit def train_step(params, opt_state, model_input, ground_truth):
for step in range(11):
output_data = model.apply(model.init(jax.random.PRNGKey(0), input_data), input_data)