google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Inconsistent results when module is a property of another #3956

Closed epignatelli closed 5 months ago

epignatelli commented 5 months ago

What's the issue

When the exact same module is passed in as an attribute of another module instead of being built inside it, the results are different.

For example, in the following, I expect Foo and Bar to return the same values, but they don't.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    encoder: nn.Module
    head_size: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.head_size)(self.encoder(x))

class Bar(nn.Module):
    head_size: int

    @nn.compact
    def __call__(self, x):
        network = nn.Sequential([
            nn.Dense(32),
            jax.nn.tanh,
            nn.Dense(self.head_size)
        ])

        return network(x)

What you expected to happen:

Foo and Bar should give the same results when encoder is:

encoder = nn.Sequential([
    nn.Dense(32),
    jax.nn.tanh,
])

Steps to reproduce:

https://colab.research.google.com/drive/1NdBGB7ue1V-11ebJr8H4K3HFTh6OfJnM?usp=sharing

cgarciae commented 5 months ago

Hey! As noted in the RNG guide, the random keys generated depend on the Module path. Because the two variants you showed here give the Modules different paths, it is expected that they have different initial weights.

epignatelli commented 5 months ago

Ah, sorry, I hadn't looked at it. That makes a lot of sense. Thanks for the quick answer!