google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

modifying params of flax.linen.Module model #3730

Open lopa23 opened 9 months ago

lopa23 commented 9 months ago

I am trying to add both a flax.linen.Dense as well as a stax.Dense layer to a model (I have to do this for something else). However, the stax parameters do not show up in the model parameters. I have tried in vain to modify the params, but it keeps giving me the error ValueError('FrozenDict is immutable.'). Any help is appreciated. Code is below:

    from jax.example_libraries import stax
    import flax.linen as nn
    import jax
    import jax.numpy as jnp
    import optax
    from optax import adam
    from flax.core.frozen_dict import freeze, unfreeze

    class StaxLayerModule(nn.Module):
        def setup(self):
            # Initialize the stax layer and get initial parameters
            init_fn, self.apply_fn = stax.Dense(5)
            _, self.stax_params = init_fn(jax.random.PRNGKey(0), (5,))

            W = self.stax_params[0]
            b = self.stax_params[1]
            self.Dense1 = nn.Dense(5, kernel_init=nn.initializers.normal(.5))
            self.nn_params = self.Dense1.init(jax.random.PRNGKey(0), jnp.ones((10,)))
            self.nn_params = unfreeze(dict(self.nn_params))

            # Copy the stax parameters into the flax param dict under new keys
            stx_param_dict = dict()
            stx_param_dict['kernel'] = W
            stx_param_dict['bias'] = b
            for key, value in stx_param_dict.items():
                self.nn_params[key + '_stax'] = value

            # Convert back to FrozenDict before using it as params
            self.nn_params = freeze(self.nn_params)

        def __call__(self, inputs):
            x = self.Dense1(inputs)
            x = self.apply_fn(self.stax_params, x)
            return x

    # Example usage
    model = StaxLayerModule()
    input_data = jnp.full((3, 10), 3)  # Example input data
    print("Input", input_data)
    ground_truth = jnp.zeros((3, 5))
    optimizer = adam(learning_rate=.01)
    rng = jax.random.PRNGKey(0)
    params = model.init(rng, input_data)

    opt_state = optimizer.init(params)

    @jax.jit
    def mse(params, x, y):
        # Define the squared loss for a single pair (x, y)
        pred = model.apply(params, x)
        return jnp.mean((ground_truth - pred) ** 2)

    loss_grad_fn = jax.value_and_grad(mse)

    @jax.jit
    def update_params(grads, params, opt_state):
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    @jax.jit
    def train_step(params, opt_state, model_input, ground_truth):
        loss_val, grads = loss_grad_fn(params, model_input, ground_truth)
        params, opt_state = update_params(grads, params, opt_state)
        return params, opt_state

    for step in range(11):
        params, opt_state = train_step(params, opt_state, input_data, ground_truth)

        model_output = model.apply(params, input_data)
        loss_val = jnp.mean((ground_truth - model_output) ** 2)
        print(f'Loss step {step}: ', loss_val)
        if step % 10 == 0:
            print(model_output)
            if loss_val < 0.0005:  # .0005 for 256
                break

    output_data = model.apply(model.init(jax.random.PRNGKey(0), input_data), input_data)
    print(output_data)

IvyZX commented 9 months ago

The code is a bit unreadable in its current format, but from the error it seems solvable by passing mutable=True in your apply() calls. See more in the documentation of apply.
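For reference, a minimal sketch of what passing mutable=True could look like; the module and shapes below are illustrative placeholders, not the code from this issue:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Hypothetical toy module, just to illustrate the mutable=True call.
    class ToyModule(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(5)(x)

    model = ToyModule()
    x = jnp.ones((3, 10))
    variables = model.init(jax.random.PRNGKey(0), x)

    # With mutable=True, every variable collection is treated as mutable, and
    # apply() returns a (output, mutated_variables) tuple instead of just the output.
    output, mutated_variables = model.apply(variables, x, mutable=True)

Which collection filter to pass (True, a collection name, or a list of names) depends on what the module actually writes to; see the apply documentation mentioned above.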