google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

It's very difficult to write libraries that support both Haiku and plain Jax #100

Open NeilGirdhar opened 3 years ago

NeilGirdhar commented 3 years ago

I would like to extend my fixed point solver to work in both Haiku and Jax. I thought it would be as simple as replacing jax.lax.scan with hk.scan, etc. Unfortunately, I get

 ValueError: hk.jit() should not be used outside of hk.transform. Use jax.jit() instead.

Perhaps I'm missing something, but would it be possible to reverse the design decision (https://github.com/deepmind/dm-haiku/pull/17) to raise an error and instead simply fall back to the Jax version of the command if the stateful context isn't needed?
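
Roughly what I have in mind (a sketch only; inside_transform and stateful_jit below are made-up names for whatever Haiku uses internally):

# Hypothetical sketch of the proposed fallback, not Haiku's actual code.
def jit(fun, **kwargs):
  if base.inside_transform():            # made-up check: is an hk.transform active?
    return stateful_jit(fun, **kwargs)   # Haiku's stateful wrapper
  return jax.jit(fun, **kwargs)          # otherwise fall back to plain JAX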

Also, just out of curiosity, but is jacfwd broken in Haiku, or does it not need a stateful wrapper?

@trevorcai WDYT?

trevorcai commented 3 years ago

I think that's fine. I (personally) won't be able to make time to make the change, but I'm happy to review a PR that does.

@tomhennigan for visibility

NeilGirdhar commented 3 years ago

Okay, great, thanks for the quick reply. I'll start working on a PR.

tomhennigan commented 3 years ago

Hi @NeilGirdhar and @trevorcai , I have some reservations about lifting this restriction. I think we should try as much as possible to avoid Haiku "leaking" into more of your program than it needs to be, and this is why hk.scan etc should only be used when you know that the code running is going to be hk.transform-ed.

One concrete reason why this might be a bad idea is that you would be locking your users into using Haiku for NN, while they may prefer to use another library (e.g. Flax, Trax, etc.). The lock-in comes from the fact that each OOP JAX library needs to provide its own drop-in for scan et al. when it is managing the state, so if you want to support Flax you would need to be able to switch between hk.scan and flax.scan (I assume they have one).

I'll reply on your other issues in a moment, but I wonder if it would be easier for your library to instead integrate with Haiku's pure functions that you get back from transform (e.g. f.apply)? If you are able to do so then your users would be able to use any NN library by simply producing the same function signature.
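
For concreteness, a rough sketch of what I mean (my_solver here is just a stand-in for your library's entry point):

# Illustrative only: the library sees a pure fn(params, x) -> y and never imports Haiku.
import haiku as hk
import jax
import jax.numpy as jnp

def my_solver(pure_fn, params, x):
  # Library code: works with any pure function of (params, x).
  return jax.jit(pure_fn)(params, x)

f = hk.transform(lambda x: hk.Linear(3)(x))
rng = jax.random.PRNGKey(0)
x = jnp.ones([1, 4])
params = f.init(rng, x)
y = my_solver(lambda p, inputs: f.apply(p, rng, inputs), params, x)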

NeilGirdhar commented 3 years ago

Hi Tom, thanks for the detailed explanation.

when you know that the code running is going to be hk.transform-ed.

The problem is that you don't know whether the code is going to be transformed. Let's look at some concrete examples.

In my exponential family library, efax, I calculate the Fisher information using grad here. There's no way for me to know whether this will be called from Jax code or Haiku.

The problem I'm actually running into is in tjax with my fixed point solver here. Same as before, there's no way to know whether this will be called by Jax or Haiku.
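
To make the tension concrete, here's a toy stand-in for that kind of library code (not the real tjax solver):

# The same library function may be called from plain JAX code or from inside
# hk.transform, and it has no way of knowing which.
import jax
import jax.numpy as jnp

def fixed_point(f, x0, tol=1e-6):
  # Library code: uses jax.lax.while_loop and knows nothing about Haiku.
  def cond(carry):
    x, x_prev = carry
    return jnp.abs(x - x_prev) > tol
  def body(carry):
    x, _ = carry
    return f(x), x
  x, _ = jax.lax.while_loop(cond, body, (f(x0), x0))
  return x

# Called from plain JAX: fine.
fixed_point(lambda x: 0.5 * (x + 2.0 / x), jnp.array(1.0))

# Called from inside hk.transform with an f that creates parameters: the loop
# needs the Haiku-aware wrapper, otherwise tracers leak.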

One concrete reason why this might be a bad idea, is that you are now locking your users into using Haiku for NN,

Yes, I agree with your point. I would love to hear alternatives. The issue is that this code leaks tracers if it's called from Haiku and isn't properly wrapped. I would love it if my libraries didn't have to know about Haiku.

it would be easier for your library to instead integrate with Haiku's pure functions that you get back from transform

Same as with the other problem, it's not possible because these functions are called deep in the code. If they were the outermost thing (for example, if I was always only calculating the Fisher information of the whole network, or the fixed point iteration was only ever applied to the whole network), then I agree that this is a reasonable workaround.

Maybe, to prevent the locking-in behavior, we could come up with a way for the wrappers to be registered with Jax? It would be more work on the Jax side, but it would be much nicer for users. Something like this:
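
(A rough sketch; register_override and inside_haiku_transform are made-up names, purely to illustrate the idea:)

# Purely illustrative -- none of these registration hooks exist in JAX today.
import jax
import haiku as hk

def inside_haiku_transform() -> bool:
  # Made-up predicate: is an hk.transform currently active?
  ...

# Haiku registers its stateful wrappers once, e.g. at import time.
jax.register_override(jax.jit, hk.jit, when=inside_haiku_transform)
jax.register_override(jax.lax.scan, hk.scan, when=inside_haiku_transform)

# Library code keeps calling jax.jit / jax.lax.scan; JAX delegates to the
# Haiku versions automatically whenever a transform is active.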

That makes things easier for users since they never have to remember to use hk.jit, etc. That's one less gotcha for new users (I just burned a day on that one).

Another benefit is that it would simplify Haiku's interface by eliminating the user-facing functions hk.jit, etc.

Finally, no one is trapped in any choice, and my libraries don't even need to know about Haiku. They just blindly use jax.jit and everything works since Jax will appropriately delegate.

What do you think?

tomhennigan commented 3 years ago

I really like the idea of enabling libraries to override/monkey patch jit et al and I think I've discussed this before with @shoyer and @jekbradbury although I can't find the relevant issue. Do either of you remember if we filed a github issue?

tomhennigan commented 3 years ago

Ah here is the issue google/jax#4117.

NeilGirdhar commented 3 years ago

@tomhennigan Nice! I love the way they implemented that. I'm not sure what I should do while I wait for that pull request to be merged. Maybe I should merge that PR locally to my version of Jax and then add the appropriate Haiku calls (if you haven't got those somewhere already?)

tomhennigan commented 3 years ago

I think that sounds good. It looks like the PR mostly needs rebasing and someone to merge it, so I don't think you'll need to wait for too long.

Wrt how/if we make this a part of Haiku, I think we'll need to think quite carefully about whether there are performance implications (e.g. with transforms like jax.checkpoint) before enabling it by default. I'd suggest that for now we make it easy for users to opt in to Haiku's behaviours with a decorator:

import functools
import haiku as hk
import jax

_TRANSFORMS = {'lax.scan': hk.scan}  # etc ..

def override_jax_transforms(f):
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    # jax.override_context is the override mechanism proposed in google/jax#4117.
    with jax.override_context(_TRANSFORMS):
      return f(*args, **kwargs)
  return wrapped

# user code
@hk.experimental.override_jax_transforms
def f(x):
  ...

# everything else stays the same
f = hk.transform(f)
params = f.init(..)

NeilGirdhar commented 3 years ago

Looks great! Looking forward to this.

shoyer commented 3 years ago

Yes, I'm happy to revive google/jax#4117 :)

My original motivation was actually exactly this issue: we have code that we want to support in both Haiku and plain JAX. So far we've gotten around this by writing our own stateful versions of higher-order functions like jit for switching back and forth, but that isn't very extensible.

NeilGirdhar commented 3 years ago

I'm working with @shoyer's pull request, and entering the context manager, but I'm getting

ValueError: hk.while_loop does not support initialization (since we cannot statically determine if your loop will run at least once). Please use `hk.running_init` to run the body unconditionally:

    if hk.running_init():
      # Unconditionally connect the module at init time.
      val = module(val)
    else:
      val = hk.while_loop(lambda val: val.mean() < 1, module, val)

First off, what about parameters that are initialized by the cond_fun? Are those missed?

Also, I can do something like what's recommended in my code, but this is very painful for libraries that don't want to know anything about Haiku. Why not just make hk.while_loop assume that the loop will run once for the sake of initialization? That is, why not just have:

# Sketch of a modified hk.while_loop:
def while_loop(cond_fun, body_fun, init_val):
  if not base.params_frozen():   # i.e. we are initializing
    cond_fun(init_val)           # connect anything the condition creates
    return body_fun(init_val)
  ...

I guess this doesn't work in some very weird cases where the parameter-getter in body_fun has a different initializer than the parameter-getter after the while loop.
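
A contrived illustration of what I mean (hypothetical snippet):

# The same parameter requested with two different initializers: the one in
# body_fun wins, because it's the one that runs at init time under this scheme.
def body_fun(x):
  w = hk.get_parameter("w", [], init=jnp.zeros)
  return x + w

def after_the_loop(x):
  w = hk.get_parameter("w", [], init=jnp.ones)  # never used for initialization
  return x * w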

If this is unacceptable, then what about simply forcing the user to have every parameter that's used in the condition or body of a while loop to have already been initialized before the while loop? Something like,

# Sketch of a modified hk.while_loop:
def while_loop(cond_fun, body_fun, init_val):
  if not base.params_frozen():
    try:
      with base.assert_state_unchanged():
        cond_fun(init_val)
        return body_fun(init_val)
    except StateChangedError as e:
      raise hk.StateChangedError("""No part of the Haiku-managed state can be initialized in a while_loop.

Try to initialize the state beforehand.  For example,
    # Unconditionally initialize the state.
    jax.initialize(cond_fun, init_val)
    jax.initialize(body_fun, init_val)
    val = hk.while_loop(cond_fun, body_fun, init_val)""") from e
  ...

The problem with the latter approach is that you might have to convince the Jax team to add some kind of corresponding hook, like

@overrideable('initialize')
def initialize(f, *args, **kwargs):
  pass

so that libraries can run this initialization code without knowing about Haiku. Haiku would then override it to

def initialize(f, *args, **kwargs):
  if hk.running_init():
    f(*args, **kwargs)

Thoughts?

tomhennigan commented 3 years ago

Sorry for the long delay, it's been a busy few weeks and I've not had the headspace to dig into this.

First off, what about parameters that are initialized by the cond_fun? Are those missed?

IIRC cond has a requirement that the output structure of each branch must be the same, so I think by construction cond will only support parameter creation if the same creation happens in both branches. Otherwise you will get an error from JAX.

Since we know one of the branches will run, and both branches create/use the same params we can safely support creation in cond.
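
For example, something like this is fine (a sketch, assuming hk.cond mirrors lax.cond's (pred, true_fun, false_fun, operand) signature):

# Both branches connect the same module, so the parameters created at init
# time are identical regardless of which branch would run.
def f(x, pred):
  mod = hk.Linear(8)
  return hk.cond(pred,
                 lambda x: 2.0 * mod(x),  # true branch
                 lambda x: mod(x),        # false branch
                 x)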

what about simply forcing the user to have every parameter that's used in the condition or body of a while loop to have already been initialized before the while loop?

This is what is implemented; we enforce it by requiring you to write if hk.running_init(): y = body() else: y = hk.while_loop(..).

I guess what you're asking is whether we can allow hk.while_loop but only throw the error if you update some state in the body (and we're running the init fn)? I think that could be fine, but in that case it would be appropriate to use lax.while_loop directly.

@shoyer's pull request

It seems like there has not been much movement there. An alternative solution might be to document some designs that would allow library authors to accept functionally impure code and make use of jax transforms. For example, instead of:

def my_library_f(f, x):
  x = jax.some_transform(f)(x)
  ...
  return x

We could suggest libraries support users passing in the transforms to use:

def my_library_f(f, x, some_transform=jax.some_transform):
  x = some_transform(f)(x)
  return x

my_library_f(hk.Linear(1), x, hk.some_transform)

In the while loop case, users can then work around any library specific restrictions (as we have in hk.while_loop) without the library needing to care:

if hk.running_init():
  x = hk.Linear(1)(x).reshape(..)
else:
  x = my_library_f(hk.Linear(1), x, hk.some_transform)

Thoughts?

There are other ways to solve the "state management" problem. Haiku, Flax et al. have point solutions outside JAX (which are convenient at first, but have some very sharp edges when combined with jax transforms and other libraries). JAX itself could support "implicit state", so at least the sharp edges would be consistent across stateful libraries. Doing so without losing the beautiful simplicity of JAX's current explicit data flow design will be a challenge. I know that @LenaMartens, @mattjj, @jekbradbury and others have been thinking about this for a while, but there isn't a clear solution.

NeilGirdhar commented 3 years ago

(I'm sick in bed, so apologies if this reply doesn't make sense.)

IIRC cond has a requirement that the output structure of each branch must be the same, so I think by construction cond will only support parameter creation if the same creation happens in both branches. Otherwise you will get an error from JAX.

Sorry, I wasn't clear. If you look at the error message in hk.while_loop, it says to

    if hk.running_init():
      # Unconditionally connect the module at init time.
      val = module(val)
    else:
      val = hk.while_loop(lambda val: val.mean() < 1, module, val)

I was just saying that you might also want to recommend connecting the condition function during the initialization phase as well, to make sure that state created by the condition shows up in the initialized state. For example:
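
    if hk.running_init():
      # Connect the condition as well as the body at init time.
      cond_fun(val)
      val = module(val)
    else:
      val = hk.while_loop(cond_fun, module, val)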

I guess what you're asking is whether we can allow hk.while_loop but only throw the error if you update some state in the body (and we're running the init fn)?

I'm just saying that ultimately I figured that hk.while_loop would be a drop-in replacement for jax.lax.while_loop, so library code can't use if hk.running_init(). I'm suggesting modifying hk.while_loop so that it only throws the error if you update state that's not already initialized. Updating state that is already initialized should be fine?

We could suggest libraries support users passing in the transforms to use…

To my eyes, @shoyer's solution is extremely elegant by comparison.

and others have been thinking about this for a while, but there isn't a clear solution.

Yes, I understand your point. Designing this well is going to make a really big difference. I appreciate all of the thought you all are putting into this. Jax is a marvel of beautiful design, and Haiku is getting there!

shoyer commented 3 years ago

We decided to not merge https://github.com/google/jax/pull/4117.

But let me share how we've solved this problem in our own codebase, using our own versions of higher order functions like scan.

First, we define a version of scan that does the right thing for initialization:

import jax
import jax.numpy as jnp
import contextlib

_INITIALIZING = False

@contextlib.contextmanager
def init_context():
  global _INITIALIZING
  assert not _INITIALIZING
  _INITIALIZING = True
  try:
    yield
  finally:
    _INITIALIZING = False

def init_safe_scan(f, init, xs, length=None, default_scan=jax.lax.scan):
  # version of lax.scan that allows for use under flax/haiku initialization 
  if _INITIALIZING:  # could also use hk.running_init() here
    xs_flat, treedef = jax.tree_flatten(xs)
    if length is None:
      length, = {x.shape[0] for x in xs_flat}
    x0 = jax.tree_unflatten(treedef, [x[0] for x in xs_flat])
    carry, y0 = f(init, x0)
    ys = jax.tree_multimap(lambda *z: jnp.stack(z), *(length * [y0]))
    return carry, ys
  return default_scan(f, init, xs, length)

Then in Haiku, you can write something like:

import haiku as hk

def neural_net(x):
  return hk.Linear(5)(x)

def haiku_init_safe_scan(f, init, xs, length=None):
  return init_safe_scan(f, init, xs, length=length, default_scan=hk.scan)

def my_model(step_fn):
  def doubled_step(x):
    y, _ = haiku_init_safe_scan(
      lambda x, _: (step_fn(x), _), init=x, xs=jnp.arange(2))
    return y
  return doubled_step

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 5])
forward = hk.transform(my_model(neural_net))
with init_context():
  params = forward.init(rng, x)
print(params)  # only a single set of weights
logits = forward.apply(params, rng, x)
print(logits)  # does not crash

Presumably this sort of thing could be done for most/all higher-order functions in JAX. It doesn't even have to be library-specific, so I can imagine this being a good fit for a third-party library or perhaps even a jax.experimental module (but not built into JAX core).
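
For instance, a sketch of the same pattern applied to while_loop (following the run-the-body-once-at-init idea discussed above; untested):

def init_safe_while_loop(cond_fun, body_fun, init_val,
                         default_while_loop=jax.lax.while_loop):
  # Same trick as init_safe_scan: at init time, run the condition and body once
  # unconditionally so any parameters they create get registered.
  if _INITIALIZING:
    cond_fun(init_val)
    return body_fun(init_val)
  return default_while_loop(cond_fun, body_fun, init_val)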