google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Issue with Optimizer Update in A2C Network with Optax #4391

Open Tomato-toast opened 2 days ago

Tomato-toast commented 2 days ago

Hello everyone,

I've encountered a problem while implementing an A2C (Advantage Actor-Critic) network with Flax and Optax. My network includes policy_network and value_network, each containing a policy_head and a torso. When attempting to call optimizer.update(grad), I received the following error:

    ValueError: Mismatch custom node data: ('policy_head', 'torso') != ('policy_network', 'value_network');

The error message indicates that the expected keys are ('policy_network', 'value_network'), but the keys actually provided are ('policy_head', 'torso'). The structure of my model parameters is as follows:

    State({
        'policy_network': { 'policy_head': {...}, 'torso': {...} },
        'value_network': { 'policy_head': {...}, 'torso': {...} },
    })

I have tried to combine the model parameters and pass them to the optimizer, like this:

    params = {'w1': model1_params, 'w2': model2_params}

However, this approach did not resolve the issue. I'm wondering if there is another way to correctly initialize and update the A2C network's parameters with Optax in Flax.

If you have any suggestions or need more information, please let me know. Thank you very much for your help!

cgarciae commented 2 days ago

Hi @Tomato-toast, can you post some pseudo code of how you are constructing the Optimizer and gradients?

Tomato-toast commented 1 day ago

> Hi @Tomato-toast, can you post some pseudo code of how you are constructing the Optimizer and gradients?

Below is a pseudo-code example of how the Optimizer and gradients are constructed and applied:

    class ConnectorTorso(nnx.Module):
        def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
            self.rngs = rngs
            self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
        def __call__(self, x):
            x = self.linear(x)
            return x

    def make_actor_network_connector(rngs = nnx.Rngs(0)):
        class PolicyNetwork(nnx.Module):
            def __init__(self, in_features, out_features):
                self.torso = ConnectorTorso(in_features, out_features, rngs=rngs)

            def __call__(self, x):
                return self.torso(x)

        return PolicyNetwork

    def make_critic_network_connector(rngs = nnx.Rngs(0)):
        class CriticNetwork(nnx.Module):
            def __init__(self, in_features, out_features):
                self.torso = ConnectorTorso(in_features, out_features, rngs=rngs)

            def __call__(self, x):
                return self.torso(x)

        return CriticNetwork

    class A2CAgent:
        def __init__(self):
            self.optimizer = nnx.Optimizer(
                optax.adam(learning_rate=1e-3)
            )
        # Initialize the policy network and value network parameters.
        def init_params(self, key: chex.PRNGKey) -> ParamsState:
            _, policy_params, _ = nnx.split(self.actor_critic_networks.policy_network, nnx.Param, ...)
            _, critic_params, _ = nnx.split(self.actor_critic_networks.value_network, nnx.Param, ...)

            params = ActorCriticParams(
                actor = policy_params,
                critic = critic_params,
            )
            params_state = ParamsState(
                params=params,
                opt_state=opt_state,
                update_count=jnp.array(0, float),
            )
            return params_state

        def a2c_loss(self, policy_network, params, observations, actions, returns):
            # Calculate the outputs of the policy and value networks
            policy_output = policy_network(params.actor, observations)
            critic_output = policy_network(params.critic, observations)

            # Policy Loss: Based on the Advantage Function
            advantages = returns - critic_output
            policy_loss = -jnp.mean(jnp.log(policy_output) * advantages)

            # Value Loss: Mean Squared Error (MSE)
            critic_loss = jnp.mean((critic_output - returns) ** 2)

            # Entropy Loss: Encouraging Policy Exploration
            entropy_loss = -jnp.mean(policy_output * jnp.log(policy_output + 1e-8))

            # total loss
            return policy_loss + critic_loss - 0.01 * entropy_loss  # entropy loss coefficient is 0.01

        # Execute a training epoch and update the parameters.
        def run_epoch(self, training_state: TrainingState) -> Tuple[TrainingState, Dict]:
            grad_fn = nnx.value_and_grad(self.a2c_loss, has_aux=True)
            ((policy_output, critic_output), metrics), grad = grad_fn(actor_critic_networks.policy_network, params, training_state.acting_state)
            grad = optax.clip_by_global_norm(grad, max_norm=1.0)
            updates, opt_state = self.optimizer.update(grad, training_state)
            params = optax.apply_updates(training_state.params_state.params, updates)
            training_state = TrainingState(
                params_state=ParamsState(
                    params=params,  
                    opt_state=opt_state,           
                    update_count=training_state.params_state.update_count + 1,
                ),
                acting_state=acting_state,  
            )
            return training_state, metrics

Thanks!

stergiosba commented 16 hours ago

I would just offer my input here and some suggestions based on my relatively short experience with NNX.

I noticed you are using the Linen-style TrainState and that you also split the graphdef and parameters using nnx.split, so I am going to assume you need backwards compatibility with Linen. If that is not the case, you should be happy to know that in flax.nnx you don't have to do this anymore, at least for this simple example. That being said, splitting the parameters and using a TrainState is a perfectly fine, working option (I don't like it personally, which is why I recently switched from Linen to NNX).
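
For reference, this is roughly what the pure NNX style looks like on a toy module (a minimal sketch under my own assumptions, not your A2C setup; the module, shapes, and names are made up):

    import jax.numpy as jnp
    import optax
    from flax import nnx

    # Toy module and optimizer; nnx.Optimizer keeps the optax state for us.
    model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))

    @nnx.jit
    def train_step(model, optimizer, x, y):
        def loss_fn(model):
            return jnp.mean((model(x) - y) ** 2)

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(grads)  # updates the model's parameters in place
        return loss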

OK, to the matter at hand: the problem here is with the a2c_loss function definition and the way it's transformed with nnx.value_and_grad.

You have:

    def a2c_loss(self, policy_network, params, observations, actions, returns):

When you transform with nnx.value_and_grad, you take the derivative with respect to the argument selected by argnums, whose default is shown in the signature of nnx.value_and_grad:

    flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

You can see that the default is argnums=0, which you leave at the default in your code:

    grad_fn = nnx.value_and_grad(self.a2c_loss, has_aux=True)

Thus you are taking the derivative with respect to self in a2c_loss. Now you can change it to the following and tell us what you get:

    grad_fn = nnx.value_and_grad(self.a2c_loss, argnums=2, has_aux=True)

This will take the derivative with respect to params.
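
To illustrate argnums in isolation (a standalone toy example, unrelated to your A2C code; the function and shapes are made up):

    import jax.numpy as jnp
    from flax import nnx

    def loss_fn(scale, model, x, y):
        # `model` sits at position 1, so argnums=1 selects it for differentiation;
        # the default argnums=0 would target `scale` instead.
        return scale * jnp.mean((model(x) - y) ** 2)

    model = nnx.Linear(3, 1, rngs=nnx.Rngs(0))
    x, y = jnp.ones((4, 3)), jnp.zeros((4, 1))

    loss, grads = nnx.value_and_grad(loss_fn, argnums=1)(0.5, model, x, y)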

On a more general note, you can simplify your code significantly. For example, ConnectorTorso and the two heads (actor and critic) could be combined into a single module, which will probably lead to faster compilation if your aim is to build a more complex model in the future. Just some food for thought.
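
Something along these lines, for example (a hypothetical sketch; the layer sizes and names are made up):

    import jax
    from flax import nnx

    class ActorCritic(nnx.Module):
        """Shared torso with separate policy and value heads (illustrative only)."""

        def __init__(self, in_features: int, hidden: int, num_actions: int, rngs: nnx.Rngs):
            self.torso = nnx.Linear(in_features, hidden, rngs=rngs)
            self.policy_head = nnx.Linear(hidden, num_actions, rngs=rngs)
            self.value_head = nnx.Linear(hidden, 1, rngs=rngs)

        def __call__(self, x):
            h = jax.nn.relu(self.torso(x))
            return self.policy_head(h), self.value_head(h)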

cgarciae commented 10 hours ago

@Tomato-toast how are the policy_network functions implemented?

    policy_output = policy_network(params.actor, observations)
    critic_output = policy_network(params.critic, observations)

They seem to be Modules that take in their own params, which is a bit peculiar.

Since you are using a functional-style training loop, I'd recommend storing the graphdefs and using nnx.merge to reconstruct the Modules inside the loss function before calling them. Check out examples/nnx_toy_examples/03_train_state.py, which shows how to use NNX with TrainState.
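
For example, something along these lines (a standalone toy sketch of that pattern, not your A2C code; the module, shapes, and optimizer are made up):

    import jax
    import jax.numpy as jnp
    import optax
    from flax import nnx

    # Split once outside the training loop; merge inside the loss.
    model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))
    graphdef, params = nnx.split(model, nnx.Param)

    def loss_fn(params, x, y):
        model = nnx.merge(graphdef, params)  # rebuild the Module from graphdef + params
        return jnp.mean((model(x) - y) ** 2)

    tx = optax.adam(1e-3)
    opt_state = tx.init(params)

    x, y = jnp.ones((4, 3)), jnp.zeros((4, 2))
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)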

Regarding the argnums situation, a small correction to what @stergiosba pointed out: yes, you should change argnums to match the position of params, but because self is passed via a bound method it doesn't count, so it should be argnums=1.
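
In other words, with your current a2c_loss signature it would look roughly like this:

    grad_fn = nnx.value_and_grad(self.a2c_loss, argnums=1, has_aux=True)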