google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Feature request: ability to apply stop gradient to some parameters #1857

Closed NeilGirdhar closed 2 years ago

NeilGirdhar commented 2 years ago

To motivate this feature request, I'll explain what I'm currently doing (without Flax) and the other solutions I've considered. Then I'll suggest some possible Flax interfaces.

Problem:

In the process of running inference in one of my modules, I need to mask varying subsets of the weights with stop-gradient within a single function:

(click for long snippet)

```python
def infer(encoding: EncodingElement, observation: PoolingMessage, prediction: PredictionMessage,
          rng: Generator, weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()
    # Create four copies of the weights:
    # * weights_sg has stop_gradient applied to all weights, and
    # * the other three have stop_gradient applied to different partitions of the weights.
    weights_sg, weights_g, weights_c, weights_e = _stop_gradient_on_some_weights(weights)

    # Inference --------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    code_message = encoding.code_message(observation, weights_sg)

    # GLN loss ---------------------------------------------------------------------------------
    # The scan parameters depend on weights_g.
    encoding_parameters_g = SamplerParameters(observation, prediction, weights_g)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan
    # through to the observation.
    initial_code_message = stop_gradient(code_message)
    # This class manages an iterated function (a scan).
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # This is an extremely computationally expensive scan.
    sampler_state, sampler_trajectory = sampler.sample_trajectory(
        encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss --------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts
    # the code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    c_code_message = encoding.code_message(observation, weights_c, rng=code_rng,
                                           use_code_signal_noise=True)
    # When this loss is minimized, only the weights that are not marked stop-gradient in
    # weights_c are adjusted.  Cotangents are also blocked from poisoning the scan by applying
    # stop_gradient to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.
    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)


# Below is the code that uses Haiku to partition the weights and apply stop gradient to
# different partitions.
_module_classes = [{'gln'}, {'code_value', 'code_presence'}, {'explanation'}]


def _module_predicate(module_name: str, name: str, value: Array) -> int:
    prefix = module_name.split('/')[0]
    for i, prefix_set in enumerate(_module_classes):
        if prefix in prefix_set:
            return i
    raise RuntimeError


# I was using Haiku before, but I'll have to port this to Flax somehow.
def _partition_by_module(weights: FrozenVariableDict) -> tuple[FrozenVariableDict, ...]:
    return hk.data_structures.partition_n(_module_predicate,  # type: ignore[arg-type]
                                          weights, len(_module_classes))


def _stop_gradient_on_some_weights(weights: FrozenVariableDict) -> list[FrozenVariableDict]:
    weights_sg = stop_gradient(weights)
    weights_p = _partition_by_module(weights)
    weights_sg_p = _partition_by_module(weights_sg)
    return ([weights_sg]
            + [hk.data_structures.merge(weights_pi,
                                        *[weights_sg_pi
                                          for j, weights_sg_pi in enumerate(weights_sg_p)
                                          if i != j])
               for i, weights_pi in enumerate(weights_p)])
```

Non-solution:

I discussed this with @cgarciae and we brainstormed a non-solution: I could try to put the "C", "G", and "E" weights into different "collections" and then run inference three times. This doesn't work because:

Possible Flax interface:

We came up with two Flax interfaces that might work.

I suggested some kind of context manager flax.linen.stop_gradient:

(click for long snippet)

```python
def infer(encoding: EncodingElement, observation: PoolingMessage, prediction: PredictionMessage,
          rng: Generator, weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()

    # Inference --------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    with nn.stop_gradient(lambda c: True):
        code_message = encoding.code_message(observation)

    # GLN loss ---------------------------------------------------------------------------------
    encoding_parameters_g = SamplerParameters(observation, prediction)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan
    # through to the observation.
    initial_code_message = stop_gradient(code_message)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # The scan parameters depend on weights_g.
    with nn.stop_gradient(lambda c: c.name.startswith('gln')):
        # This class manages an iterated function (a scan).
        # This is an extremely computationally expensive scan.
        sampler_state, sampler_trajectory = sampler.sample_trajectory(
            encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss --------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts
    # the code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    with nn.stop_gradient(lambda c: c.name.startswith('code')):
        c_code_message = encoding.code_message(observation, rng=code_rng,
                                               use_code_signal_noise=True)
    # When this loss is minimized, only the weights that are not marked stop-gradient in
    # weights_c are adjusted.  Cotangents are also blocked from poisoning the scan by applying
    # stop_gradient to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.
    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)
```

Cristian suggested a lifting transformation like those found in flax.core.lift. I'm still learning how these work, so I can't yet sketch what this might look like.

Possible side benefits

Besides applying stop-gradient, this kind of system may be able to do other things with parameters such as:

Of course, that's beyond this feature request, but I mention these ideas as something to keep in mind when considering solutions.

Conclusion

Am I missing an easy solution to my problem? If not, I will need to solve this problem in order to use Flax since this use of stop-gradient is integral to my research. Thanks for reading!

jheek commented 2 years ago

So in the Haiku version of the code you are solving this "outside" of Haiku by operating on the variables dict directly and making 4 copies. In Flax you could do something similar, for example by using flax.traverse_util.flatten_dict. If you want to do this inside a Linen Module, you could use nn.map_variables where the mapping is basically the identity but with stop_gradient applied to some or all of the params (again, you can make your life easy here by using flatten_dict).
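
For concreteness, here is a minimal, self-contained sketch of the "outside" approach. It is not code from this thread; `stop_gradient_on` and the prefix filter are illustrative names, and a plain `nn.Dense` stands in for the real model:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import traverse_util


def stop_gradient_on(params, prefixes):
    # Flatten the nested params dict into {('dense', 'kernel'): ...}-style entries,
    # apply stop_gradient to the leaves whose first path component is in `prefixes`,
    # and rebuild the nested structure.
    flat = traverse_util.flatten_dict(params)
    flat = {k: jax.lax.stop_gradient(v) if k[0] in prefixes else v
            for k, v in flat.items()}
    return traverse_util.unflatten_dict(flat)


model = nn.Dense(3)
x = jnp.ones((1, 2))
params = model.init(jax.random.PRNGKey(0), x)['params']


def loss(p):
    # The kernel is masked with stop_gradient, so only the bias receives cotangents.
    return jnp.sum(model.apply({'params': stop_gradient_on(p, {'kernel'})}, x))


grads = jax.grad(loss)(params)  # grads['kernel'] is all zeros; grads['bias'] is not.
```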

jheek commented 2 years ago

Here's a sketch of what that would look like:

from flax import traverse_util
from jax import lax
import flax.linen as nn

def selective_stop_grad(variables):
  # Identity mapping, except that stop_gradient is applied to the parameters selected by
  # some_filter_fn (a placeholder predicate on the flattened parameter path).
  flat_vars = traverse_util.flatten_dict(variables)
  new_vars = {k: lax.stop_gradient(v) if some_filter_fn(k) else v for k, v in flat_vars.items()}
  return traverse_util.unflatten_dict(new_vars)

class MySGModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    MySGSubModule = nn.map_variables(MySubModule, "params", selective_stop_grad, init=True)
    return MySGSubModule(...)(x)
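
For what it's worth, `some_filter_fn` above is a placeholder. A concrete predicate on the flattened parameter path might look like this (hypothetical, loosely mirroring the `_module_classes` partition from the first snippet):

```python
# Hypothetical filter: stop gradients for every parameter under the 'gln' submodule,
# and let all other parameters train.
def some_filter_fn(path: tuple[str, ...]) -> bool:
    return path[0] == 'gln'
```
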
cgarciae commented 2 years ago

@jheek the map_variables solution is great, I like it a lot!

A HOWTO about freezing parameters using this strategy would be great.

NeilGirdhar commented 2 years ago

@jheek

I've been trying to implement your solution, but I can't seem to get it working for me. Here's roughly what I have:

from __future__ import annotations

from collections.abc import Callable
from dataclasses import asdict
from typing import Any, Generic, TypeVar

import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey

T = TypeVar('T', bound=nn.Module)

class StopGradientModule(nn.Module, Generic[T]):
    filter_f: Callable[[tuple[str, ...]], bool]
    submodule_cls: Callable[..., T]

    def setup(self) -> None:
        self.submodule = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient)

    def __call__(self, module: T) -> T:
        return self.submodule(**asdict(module))

    def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
        flat_vars = traverse_util.flatten_dict(variables)  # type: ignore[no-untyped-call]
        new_vars = {k: stop_gradient(v)
                    if self.filter_f(k) else v
                    for k, v in flat_vars.items()}
        return traverse_util.unflatten_dict(new_vars)  # type: ignore[no-untyped-call]

class X(nn.Module):
    def setup(self) -> None:
        self.dense = nn.Dense(10)
        # stop_gradient_all is a copy of self whose parameters are identical, but whose parameter
        # cotangents are always zero.
        self.stop_gradient_all = StopGradientModule(lambda _: True, X)

    def f(self, x: Any) -> Any:
        return self.dense(x), self.stop_gradient_all(self).dense(x)

print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))

gives

Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 44, in <module>
    print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))
  File "/home/neil/src/cmm/a.py", line 41, in f
    return self.dense(x), self.stop_gradient_all(self).dense(x)
  File "/home/neil/src/cmm/a.py", line 37, in setup
    self.dense = nn.Dense(10)
ValueError: Duplicate use of scope name: "dense"

I realize that this is currently a recursive mess, and I'm exploring the simplest way of accomplishing what I'm after.

NeilGirdhar commented 2 years ago

I'm still trying to get this working. Here's what I have now:

from __future__ import annotations

from collections.abc import Callable
from typing import Any, Generic, TypeVar

import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey
from tjax import print_generic

T = TypeVar('T', bound=nn.Module)

class StopGradientModule(nn.Module, Generic[T]):
    filter_f: Callable[[tuple[str, ...]], bool]
    submodule_cls: Callable[..., T]

    def setup(self) -> None:
        mapped_cls = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient,
                                      methods=['f'])
        self.submodule = mapped_cls()

    def f(self, x: Any) -> Any:
        print("Calling")
        return self.submodule.f(x)

    def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
        flat_vars = traverse_util.flatten_dict(variables)  # type: ignore[no-untyped-call]
        new_vars = {k: stop_gradient(v)
                    if self.filter_f(k) else v
                    for k, v in flat_vars.items()}
        return traverse_util.unflatten_dict(new_vars)  # type: ignore[no-untyped-call]

    def __call__(self):
        assert False

class X(nn.Module):
    def setup(self) -> None:
        self.dense = nn.Dense(3)

    def f(self, x: Any) -> Any:
        return self.dense(x)

    def __call__(self):
        assert False

class Y(nn.Module):
    def setup(self) -> None:
        self.x = X()
        # stop_gradient_all is a copy of x whose parameters are identical, but whose parameter
        # cotangents are always zero.
        self.stop_gradient_all = StopGradientModule(lambda _: True, X)

    def f(self, x: Any) -> Any:
        y = self.x.f(x)
        return y, self.stop_gradient_all.f(x)

    def __call__(self):
        assert False

(y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
print(y, y_prime)
print_generic(variables)

gives

Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 66, in <module>
    (y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
  File "/home/neil/src/cmm/a.py", line 61, in f
    return y, self.stop_gradient_all.f(x)
  File "/home/neil/src/cmm/a.py", line 28, in f
    return self.submodule.f(x)
  File "/home/neil/src/cmm/a.py", line 46, in f
    return self.dense(x)
  File "/home/neil/src/flax/flax/linen/linear.py", line 177, in __call__
    kernel = self.param('kernel',
flax.errors.ScopeCollectionNotFound: Tried to access "kernel" from collection "params"" in "/stop_gradient_all/map_variables(submodule)/dense" but the collection is emtpy. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeCollectionNotFound)
jheek commented 2 years ago

Ah, that's because map_variables makes collections immutable. You need to provide a function that maps the variables back on output, or pass init=True so that the mapping isn't called during init.

Can you try passing init=True to nn.map_variables? I think that should fix your issue. Actually, I accidentally removed the init=True from the example I copied from the docstring (updated my original example) :S
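
To illustrate, here is a minimal, self-contained sketch (the module names are illustrative, not the code from this thread): with `init=True` the mapping is skipped during `init`, so the wrapped submodule can create its parameters normally, and `stop_gradient` is only applied on `apply`.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import traverse_util


def stop_all_grads(variables):
    # Identity on values, but every parameter leaf is wrapped in stop_gradient.
    flat = traverse_util.flatten_dict(variables)
    return traverse_util.unflatten_dict(
        {k: jax.lax.stop_gradient(v) for k, v in flat.items()})


class Frozen(nn.Module):
    @nn.compact
    def __call__(self, x):
        # init=True: during init the mapping is not applied, so Dense can create 'kernel'
        # and 'bias'; during apply, their cotangents are zeroed.
        SGDense = nn.map_variables(nn.Dense, 'params', stop_all_grads, init=True)
        return SGDense(3)(x)


y, variables = Frozen().init_with_output(jax.random.PRNGKey(0), jnp.ones((1, 2)))
```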

NeilGirdhar commented 2 years ago

@jheek Thanks, that gets it to run, but it's still not mirroring x's parameters. It outputs:

[-1.5530705 -0.6934959  0.9631546] [ 0.246286    0.83799624 -0.91129684]
FrozenDict
    params=FrozenDict
        stop_gradient_all=FrozenDict
            submodule=FrozenDict
                dense=FrozenDict
                    bias=Jax Array (3,) float32
                            0.0000      0.0000      0.0000
                    kernel=Jax Array (3, 3) float32
                            0.3932      0.3981     -0.5165
                            0.1566     -0.0768     -0.2396
                           -0.3035      0.5167     -0.1552
        x=FrozenDict
            dense=FrozenDict
                bias=Jax Array (3,) float32
                        0.0000      0.0000      0.0000
                kernel=Jax Array (3, 3) float32
                       -0.0201     -0.6220      0.9425
                       -0.2652     -0.1386      0.8165
                       -1.2677      0.0672     -0.7958

So x and stop_gradient_all are different. Any idea how I can make it a mirror? I realize I need to pass x somehow, but I'm still not sure how.