google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

We broke the 1:1 correspondence with attribute names and variable dict names #2100

Open marcvanzee opened 2 years ago

marcvanzee commented 2 years ago

from @levskaya:

class Foo(nn.Module):
  def setup(self):
    self.foo = nn.Dense(1, name="bar")
    self.qup = self.param('baz', lambda k: jnp.zeros((1,)))
  def __call__(self, x):
    return self.foo(x) + self.qup
Foo().init(jax.random.PRNGKey(0), jnp.zeros((1,)))
FrozenDict({
    params: {
        baz: DeviceArray([0.], dtype=float32),
        bar: {
            kernel: DeviceArray([[-0.58376074]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

The above used to break loudly, and it should!

Initial investigation by @avital:

Over the past 1.5 years (I ran a test against every single commit), we actually never had any commit where the following code raised an exception:

  def test_setattr_name_var_agreement_in_setup(self):
    class Foo(nn.Module):
      def setup(self):
        self.qup = self.param('baz', lambda k: 0)
      def __call__(self):
        pass

    Foo(parent=None).init(jax.random.PRNGKey(0))

But we did, in the past, entirely disallow the use of name= for submodules defined within setup, which would have prevented setting the wrong name for a submodule in setup. We lost that guard with https://github.com/google/flax/pull/976/files

I don't think we ever had tests for the correspondence between variable names and attribute names. We do have tests that you can't define two variables with the same name in different collections, but not that the name aligns with the attribute being assigned to.

Suggestion from @jheek:

I think you could do something like this to disallow giving different names in setup:

def __setattr__(self, name, value):
  # Sketch only: `variables` stands for the module's {collection: {name: value}} dict.
  for col in variables:
    if name in variables[col]:
      assert variables[col][name] is value, f"A variable named {name} already exists. We don't allow variables and fields to have overlapping names"

cgarciae commented 2 years ago

Looking into this :eyes:

levskaya commented 2 years ago

I'm honestly afraid that this cat is already out of the bag. Many users' models (and checkpoints!) now exploit the current freedom to set the name apart from python attribute name. If we tried forcing it at this point we'd probably piss a lot of people off.

cgarciae commented 2 years ago

Based on @jheek's initial proposal I came up with this logic:

def _is_valid_field_value(name, val, variables) -> bool:
  value_found = False

  for collection in variables.values():
    for field, existing_value in collection.items():
      if val is existing_value:
        value_found = True
        if name == field:
          return True

  return not value_found

It looks for val in all collections: if it is found and the name matches, it's good; if it's found but no name matches, you get a runtime error. However, there are easy ways to get into trouble:

class Foo(nn.Module):

  def setup(self):
    self.bar = self.param("bar", lambda key: jnp.array(1))
    self.baz = self.bar # error: value was found but no `baz` key exists
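To make the three outcomes concrete, here is a minimal sketch that runs the check above against a plain dict instead of a real Flax module. The `variables` layout and the list stand-ins for arrays are assumptions for illustration only; the check compares object identity, so any distinct objects behave the same way.

```python
def _is_valid_field_value(name, val, variables) -> bool:
  value_found = False
  for collection in variables.values():
    for field, existing_value in collection.items():
      if val is existing_value:
        value_found = True
        if name == field:
          return True
  return not value_found

# Stand-in for a parameter array; only its identity is used by the check.
bar_param = [1.0]
# Assumed layout: {collection name: {variable name: value}}.
variables = {"params": {"bar": bar_param}}

# Attribute name equals the variable name: accepted.
matching = _is_valid_field_value("bar", bar_param, variables)
# Same object assigned under a second attribute name: rejected.
alias = _is_valid_field_value("baz", bar_param, variables)
# Object that is not a variable at all: accepted.
unrelated = _is_valid_field_value("hidden_size", [128], variables)

print(matching, alias, unrelated)  # True False True
```

The `alias` case is the `self.baz = self.bar` pattern: the check cannot tell a harmless second reference from a misnamed variable.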

What to do?

Seems like there is no general way to solve the issue; however, if users don't use "value types" (ints and floats) and avoid the pattern above, it should be fine.
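The "value types" caveat comes from the identity comparison (`is`). A rough sketch of the failure, assuming CPython, which caches small ints so that equal values are often the very same object. The helper `found_by_identity` is hypothetical and only mirrors the inner test of the check above.

```python
def found_by_identity(val, variables) -> bool:
  # True if `val` is the same object as any stored variable value.
  return any(val is v for col in variables.values() for v in col.values())

# A parameter initialised to a plain int, as in `lambda k: 0`.
int_vars = {"params": {"baz": 0}}
counter = 0  # unrelated field, never registered as a variable

# On CPython both zeros are one cached object, so the unrelated
# field looks like the variable `baz` stored under another name.
int_clash = found_by_identity(counter, int_vars)

# Separately allocated containers (stand-ins for arrays) are
# distinct objects even when their contents are equal.
list_vars = {"params": {"baz": [0.0]}}
list_clash = found_by_identity([0.0], list_vars)

print(int_clash, list_clash)
```

With array-valued parameters each variable is its own object, which is why the check is more reliable there.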

marcvanzee commented 1 year ago

@cgarciae is this issue still active? I see @jheek reviewed your PR #2102 but then it went stale. Maybe you two can try to work together to get this in, or if it turns out to be unfeasible we close the issue?

cgarciae commented 1 year ago

@marcvanzee can you check the internal CL? Maybe it was breaking some internal tests, which is why we stopped? The PR is also tricky because it has a ton of edge cases we can't cover.