Closed NeilGirdhar closed 2 years ago
So in the Haiku version of the code you are solving this "outside" of Haiku by operating on the variables dict directly and making 4 copies. In Flax you could do something similar, for example by using flax.traverse_util.flatten_dict
. If you want to do this inside a linen Module you could use nn.map_variables
where the mapping is basically identity but with stop_gradient applied to some or all of the params (again, you can make your life easy here by using flatten_dict).
Here's a sketch of what that would look like:
from flax import traverse_util
def selective_stop_grad(variables):
    """Return ``variables`` with ``lax.stop_gradient`` applied to selected entries.

    The nested ``variables`` dict is flattened, every entry whose flattened
    key satisfies ``some_filter_fn`` is wrapped in ``lax.stop_gradient``, and
    the result is unflattened back into a nested dict. Entries that are not
    selected are passed through unchanged.
    """
    masked = {}
    for path, leaf in traverse_util.flatten_dict(variables).items():
        # Only the entries picked out by the filter have their gradient cut.
        masked[path] = lax.stop_gradient(leaf) if some_filter_fn(path) else leaf
    return traverse_util.unflatten_dict(masked)
class MySGModule(nn.Module):
    """Apply ``MySubModule`` with its "params" routed through ``selective_stop_grad``."""

    @nn.compact
    def __call__(self, x):
        # Derive a variant of MySubModule whose "params" collection is mapped
        # by selective_stop_grad; init=True is passed as in the original sketch.
        stop_grad_cls = nn.map_variables(MySubModule, "params", selective_stop_grad, init=True)
        # ``...`` stands in for MySubModule's constructor arguments.
        submodule = stop_grad_cls(...)
        return submodule(x)
@jheek the map_variables
solution is great, I like it a lot!
A HOWTO about freezing parameters using this strategy would be great.
@jheek
I've been trying to implement your solution, but I can't seem to get it working for me. Here's roughly what I have:
from __future__ import annotations
from collections.abc import Callable
from dataclasses import asdict
from typing import Any, Generic, TypeVar
import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey
T = TypeVar('T', bound=nn.Module)
class StopGradientModule(nn.Module, Generic[T]):
    """Wrap ``submodule_cls`` with ``nn.map_variables`` and ``_selective_stop_gradient``.

    First attempt from the discussion; the traceback that follows this snippet
    shows it failing with ``Duplicate use of scope name: "dense"``.
    """
    # Predicate on a flattened variable key; a true result selects that entry for stop_gradient.
    filter_f: Callable[[tuple[str, ...]], bool]
    # Module class/factory to wrap; handed to nn.map_variables in setup.
    submodule_cls: Callable[..., T]

    def setup(self) -> None:
        # self.submodule holds the result of nn.map_variables (not yet constructed here);
        # it is called with constructor keyword arguments in __call__.
        # NOTE(review): True presumably selects every collection -- confirm against the
        # nn.map_variables documentation.
        self.submodule = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient)

    def __call__(self, module: T) -> T:
        """Build the mapped submodule from the dataclass fields of ``module``."""
        return self.submodule(**asdict(module))

    def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
        """Return ``variables`` with ``stop_gradient`` applied to entries selected by ``filter_f``."""
        flat_vars = traverse_util.flatten_dict(variables)  # type: ignore[no-untyped-call]
        new_vars = {k: stop_gradient(v)
                    if self.filter_f(k) else v
                    for k, v in flat_vars.items()}
        return traverse_util.unflatten_dict(new_vars)  # type: ignore[no-untyped-call]
class X(nn.Module):
    """Module with one ``Dense`` layer plus a stop-gradient wrapper around ``X`` itself."""

    def setup(self) -> None:
        self.dense = nn.Dense(10)
        # Intent: stop_gradient_all is a copy of self whose parameters are identical, but whose
        # parameter cotangents are always zero.
        # NOTE(review): the wrapper is given X itself, and the traceback reported with this
        # snippet ends in this setup with 'Duplicate use of scope name: "dense"'.
        self.stop_gradient_all = StopGradientModule(lambda _: True, X)

    def f(self, x: Any) -> Any:
        """Return ``dense(x)`` computed directly and through the stop-gradient wrapper."""
        return self.dense(x), self.stop_gradient_all(self).dense(x)
# Initialise X through its ``f`` method (rather than ``__call__``) and print the result.
print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))
gives
Traceback (most recent call last):
File "/home/neil/src/cmm/a.py", line 44, in <module>
print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))
File "/home/neil/src/cmm/a.py", line 41, in f
return self.dense(x), self.stop_gradient_all(self).dense(x)
File "/home/neil/src/cmm/a.py", line 37, in setup
self.dense = nn.Dense(10)
ValueError: Duplicate use of scope name: "dense"
I realize that this is currently a recursive mess, and I'm exploring the simplest way of accomplishing what I'm trying to accomplish.
I'm still trying to get this working. Here's what I have now:
from __future__ import annotations
from collections.abc import Callable
from typing import Any, Generic, TypeVar
import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey
from tjax import print_generic
T = TypeVar('T', bound=nn.Module)
class StopGradientModule(nn.Module, Generic[T]):
    """Build ``submodule_cls`` through ``nn.map_variables`` and forward ``f`` to it.

    Variables handed to the wrapped module pass through ``_selective_stop_gradient``.
    """
    # Predicate on a flattened variable key; a true result selects that entry for stop_gradient.
    filter_f: Callable[[tuple[str, ...]], bool]
    # Module class/factory to wrap; handed to nn.map_variables in setup.
    submodule_cls: Callable[..., T]

    def setup(self) -> None:
        # Only the method 'f' is listed in ``methods``.
        # NOTE(review): init=True is not passed here. Per the reply in this thread,
        # map_variables then makes the collections immutable, which matches the
        # ScopeCollectionNotFound traceback reported for this snippet.
        mapped_cls = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient,
                                      methods=['f'])
        self.submodule = mapped_cls()

    def f(self, x: Any) -> Any:
        """Print a trace line, then return the wrapped module's ``f(x)``."""
        print("Calling")
        return self.submodule.f(x)

    def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
        """Return ``variables`` with ``stop_gradient`` applied to entries selected by ``filter_f``."""
        flat_vars = traverse_util.flatten_dict(variables)  # type: ignore[no-untyped-call]
        new_vars = {k: stop_gradient(v)
                    if self.filter_f(k) else v
                    for k, v in flat_vars.items()}
        return traverse_util.unflatten_dict(new_vars)  # type: ignore[no-untyped-call]

    def __call__(self):
        # Always fails: this example drives the module through ``f`` instead.
        assert False
class X(nn.Module):
    """Module holding a single ``nn.Dense(3)`` layer, exposed through ``f``."""

    def setup(self) -> None:
        self.dense = nn.Dense(3)

    def f(self, x: Any) -> Any:
        """Return the dense layer applied to ``x``."""
        return self.dense(x)

    def __call__(self):
        # Always fails: this example drives the module through ``f`` instead.
        assert False
class Y(nn.Module):
    """Hold an ``X`` submodule and a ``StopGradientModule`` built around the ``X`` class."""

    def setup(self) -> None:
        self.x = X()
        # Intent: stop_gradient_all is a copy of x whose parameters are identical, but whose
        # parameter cotangents are always zero.
        # NOTE(review): the wrapper is given the class X, not the instance self.x. The output
        # posted later in this thread shows separate parameters under "x" and
        # "stop_gradient_all".
        self.stop_gradient_all = StopGradientModule(lambda _: True, X)

    def f(self, x: Any) -> Any:
        """Return ``(self.x.f(x), self.stop_gradient_all.f(x))``."""
        y = self.x.f(x)
        return y, self.stop_gradient_all.f(x)

    def __call__(self):
        # Always fails: this example drives the module through ``f`` instead.
        assert False
# Initialise Y through its ``f`` method; the call yields f's two outputs and the variables.
(y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
print(y, y_prime)
# Print the initialised variables with tjax's print_generic.
print_generic(variables)
gives
Traceback (most recent call last):
File "/home/neil/src/cmm/a.py", line 66, in <module>
(y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
File "/home/neil/src/cmm/a.py", line 61, in f
return y, self.stop_gradient_all.f(x)
File "/home/neil/src/cmm/a.py", line 28, in f
return self.submodule.f(x)
File "/home/neil/src/cmm/a.py", line 46, in f
return self.dense(x)
File "/home/neil/src/flax/flax/linen/linear.py", line 177, in __call__
kernel = self.param('kernel',
flax.errors.ScopeCollectionNotFound: Tried to access "kernel" from collection "params"" in "/stop_gradient_all/map_variables(submodule)/dense" but the collection is emtpy. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeCollectionNotFound)
Ah that's because map_variables makes collections immutable. You need to provide a function that maps back the variables on output or pass init=True such that during init the map_variables isn't called.
Can you try passing init=True
to nn.map_variables? I think that should fix your issue. Actually, I accidentally removed the init=True from the example I copied from the docstring (updated my original example) :S
@jheek Thanks, that gets it to run, but it's still not reflecting a copy of x's parameters? It outputs:
[-1.5530705 -0.6934959 0.9631546] [ 0.246286 0.83799624 -0.91129684]
FrozenDict
params=FrozenDict
stop_gradient_all=FrozenDict
submodule=FrozenDict
dense=FrozenDict
bias=Jax Array (3,) float32
0.0000 0.0000 0.0000
kernel=Jax Array (3, 3) float32
0.3932 0.3981 -0.5165
0.1566 -0.0768 -0.2396
-0.3035 0.5167 -0.1552
x=FrozenDict
dense=FrozenDict
bias=Jax Array (3,) float32
0.0000 0.0000 0.0000
kernel=Jax Array (3, 3) float32
-0.0201 -0.6220 0.9425
-0.2652 -0.1386 0.8165
-1.2677 0.0672 -0.7958
So x
and stop_gradient_all
are different. Any idea how I can make it a mirror? I realize I need to pass x
somehow, but I'm still not sure how.
To motivate this feature request, I'll explain what I'm currently doing (without Flax), and the other solutions I've considered. Then I'll suggest some Flax solution.
Problem:
In the process of inferring one of my modules, I need to mask varying subsets of the weights with stop-gradient in a single function:
(click for long snippet)
```python
def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()

    # Create four copies of the weights:
    # * weights_sg has stop_gradient applied to all weights, and
    # * the other three have stop_gradient applied to different partitions of
    #   the weights.
    weights_sg, weights_g, weights_c, weights_e = _stop_gradient_on_some_weights(weights)

    # Inference ------------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    code_message = encoding.code_message(observation, weights_sg)

    # GLN loss -------------------------------------------------------------------------------------
    # The scan parameters depend on weights_g.
    encoding_parameters_g = SamplerParameters(observation, prediction, weights_g)

    # This use of stop_gradient prevents the cotangents from propagating back from the scan through
    # to the observation.
    initial_code_message = stop_gradient(code_message)

    # This class manages an iterated function (a scan)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)

    # This is an extremely computationally expensive scan.
    sampler_state, sampler_trajectory = sampler.sample_trajectory(
        encoding_parameters_g, initial_sampler_state, sampler_iterations, None)

    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss ------------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts the
    # code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    c_code_message = encoding.code_message(observation, weights_c, rng=code_rng,
                                           use_code_signal_noise=True)

    # When this loss is minimized only the weights that are not marked stop-gradient in weights_c
    # are adjusted. Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.
    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)


# Below is the code that uses Haiku to partition the weights and apply stop gradient to different
# partitions.
_module_classes = [{'gln'}, {'code_value', 'code_presence'}, {'explanation'}]


def _module_predicate(module_name: str, name: str, value: Array) -> int:
    prefix = module_name.split('/')[0]
    for i, prefix_set in enumerate(_module_classes):
        if prefix in prefix_set:
            return i
    raise RuntimeError


# I was using Haiku before, but I'll have to port this to Flax somehow.
def _partition_by_module(weights: FrozenVariableDict) -> tuple[FrozenVariableDict, ...]:
    return hk.data_structures.partition_n(_module_predicate,  # type: ignore[arg-type]
                                          weights,
                                          len(_module_classes))


def _stop_gradient_on_some_weights(weights: FrozenVariableDict) -> list[FrozenVariableDict]:
    weights_sg = stop_gradient(weights)
    weights_p = _partition_by_module(weights)
    weights_sg_p = _partition_by_module(weights_sg)
    return ([weights_sg]
            + [hk.data_structures.merge(weights_pi,
                                        *[weights_sg_pi
                                          for j, weights_sg_pi in enumerate(weights_sg_p)
                                          if i != j])
               for i, weights_pi in enumerate(weights_p)])
```
Non-solution:
I discussed this with @cgarciae and brainstormed a non-solution: I could try to put the "C", "G", and "E" weights into different "collections". And then run inference three times. This doesn't work because:
Possible Flax interface:
We came up with two Flax interfaces that might work.
I suggested some kind of context manager
flax.linen.stop_gradient
:(click for long snippet)
```python
def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()

    # Inference ------------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    with nn.stop_gradient(lambda c: True):
        code_message = encoding.code_message(observation)

    # GLN loss -------------------------------------------------------------------------------------
    encoding_parameters_g = SamplerParameters(observation, prediction)

    # This use of stop_gradient prevents the cotangents from propagating back from the scan through
    # to the observation.
    initial_code_message = stop_gradient(code_message)

    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)

    # The scan parameters depend on weights_g.
    with nn.stop_gradient(lambda c: c.name.starts_with('gln')):
        # This class manages an iterated function (a scan)
        # This is an extremely computationally expensive scan.
        sampler_state, sampler_trajectory = sampler.sample_trajectory(
            encoding_parameters_g, initial_sampler_state, sampler_iterations, None)

    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss ------------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts the
    # code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    with nn.stop_gradient(lambda c: c.name.starts_with('code')):
        c_code_message = encoding.code_message(observation, rng=code_rng,
                                               use_code_signal_noise=True)

    # When this loss is minimized only the weights that are not marked stop-gradient in weights_c
    # are adjusted. Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.
    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)
```
Cristian suggested a lifting transformation like those found in
flax.core.lift
. I'm still learning how these work, so I can't yet sketch what this might look like.
Possible side benefits
Besides applying stop-gradient, this kind of system may be able to do other things with parameters such as:
Of course, that's beyond this feature request, but I mention these ideas as something to keep in mind when considering solutions.
Conclusion
Am I missing an easy solution to my problem? If not, I will need to solve this problem in order to use Flax since this use of stop-gradient is integral to my research. Thanks for reading!