keras-team / keras


current rng setup is full of footguns in jax #18426

Open GallagherCommaJack opened 11 months ago

GallagherCommaJack commented 11 months ago

right now unseeded calls to e.g. keras.random.uniform are going to acquire static seeds at trace time. this has a few undesirable consequences:

1) subsequent calls will have the same randomness each time (e.g. dropout will have a fixed mask instead of random each step)
2) the jax compiler cache will ~never hit, as the constant rng seed values will be different every time

to get around this, some kind of rng state management is necessary. flax does this with hierarchical management of rngs from the Scope. such an approach is fairly complex, however, and there might be simpler options, e.g. a single global rng state that gets included with the training state in model.fit. unseeded rng calls would then do something along the lines of

state.seed, local_seed = jax.random.split(state.seed)
GallagherCommaJack commented 11 months ago

a reference re jax rng mechanics https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng
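
for illustration, here's a minimal pure-JAX sketch (not the keras code path itself) of the mechanism: any seed obtained while tracing becomes a compile-time constant, so the jitted function replays the same "random" numbers on every call.

import time
import jax
import jax.numpy as jnp

@jax.jit
def sample(x):
    # evaluated once, at trace time: the key becomes a baked-in constant
    key = jax.random.PRNGKey(int(time.time()))
    return x + jax.random.uniform(key, x.shape)

x = jnp.zeros((3,))
print(sample(x))
print(sample(x))  # identical "random" output: the seed was frozen into the trace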

fchollet commented 11 months ago

Right -- random ops in Keras Core are basically always intended to be called with a SeedGenerator instance as the seed argument, since seed=None defaults to an integer seed, which is going to get baked into the graph you're tracing. Integer seeds only really work intuitively in eager mode.

This is a gotcha for sure. Maybe we can do something to resolve it.

Some considerations:

GallagherCommaJack commented 11 months ago

Is this a gotcha specifically in the seed=None case, or the more general seed=int case? I'm guessing the latter.

both, the int needs to come from somewhere

If we convert seed=None/int into a seed generator/variable in the background, we're going to need to be tracking it, since it needs to be updated by any train/test step function. Pretty easy to do for built-in methods, but the problem is that anyone writing custom training loops is going to have to be aware of that global state and take it into account.

if someone is writing a fully custom training loop in jax they will either be aware of this or immediately shoot their foot off. that said, there's probably some way to progressively disclose rng handling, e.g. have a "base train step" that in jax increments the rng counter

Should there be one global seed variable (used in the seed=None case) shared by all unseeded random ops (with different values for each op) or should there be one seed variable per unseeded op?

I don't quite understand the distinction here... maybe you could write some pseudocode to clarify it?

Could an alternative be to disallow seed=None/int when tracing? It would only work in eager, and require you to pass your own SeedGenerator if you're tracing

ideally we would not require the user to pass rng info all the way through model init in jax but nowhere else (and we cannot assume that model init happens in eager mode; in distributed training there isn't necessarily enough space on any given accelerator to support that)

AakashKumarNain commented 11 months ago

If we convert seed=None/int into a seed generator/variable in the background, we're going to need to be tracking it, since it needs to be updated by any train/test step function. Pretty easy to do for built-in methods, but the problem is that anyone writing custom training loops is going to have to be aware of that global state and take it into account.

Yes, but isn't that what we expect? Anyone who writes JAX is aware of PRNG handling and the consequences of not handling it properly. Anything involving the PRNG should be explicit rather than implicit. It also makes debugging easier.

fchollet commented 11 months ago

The difficulty is that there will be some reference to an RNG seed variable that you'll have to take into account, something like --

def fn(variables):
    trainable_variables = ...
    non_trainable_variables = ...
    return (trainable_variables, non_trainable_variables, keras.random.global_rng_seed())

variables = (trainable_variables, non_trainable_variables, keras.random.global_rng_seed())
fn(variables)

Anyone writing custom training loops is going to need to know about this API. If they forget it -- well, back to the current status quo, which is that their unseeded RNG calls are unchanged across iterations. Maybe not that bad if you think of it like that.

GallagherCommaJack commented 11 months ago

Anyone writing custom training loops is going to need to know about this API. If they forget it -- well, back to the current status quo, which is that their unseeded RNG calls are unchanged across iterations. Maybe not that bad if you think of it like that.

It seems like a clear improvement over the current situation, where you not only have to remember to do this, but the default APIs don't even suggest it and make it very difficult to do.

martin-gorner commented 11 months ago

+1 to proper RNG management. This is extremely important to JAX users. JAX offers good reproducibility out of the box (with some RNG learning curve for the user). It's fine if Keras can simplify the API with automatic jax.random.split(s) in the right places, but "reproducibility out of the box" should remain.

fchollet commented 11 months ago

I looked at this more closely. What I can propose is this:

  1. Unseeded random ops use a global SeedGenerator
  2. The state of the global seed generator (KerasVariable of size (2,)) is accessible via keras.random.global_rng_state()
  3. You can update the state any time you want via assign or assign_add, just like any other Keras variable.

So you can manage your unseeded random ops calls like this:

@jax.jit
def jitted_random_numbers(seed):
    rng_state = keras.random.global_rng_state()
    rng_state.assign(seed)
    x = keras.random.normal(...)
    y = keras.random.uniform(...)
    return x, y, rng_state.value

You could even have something a bit more intuitive like this:

@jax.jit
def jitted_random_numbers(seed):
    keras.random.set_global_rng_state(seed)
    x = keras.random.normal(...)
    y = keras.random.uniform(...)
    return x, y, keras.random.global_rng_state().value  # the .value won't be necessary in a future JAX version

The default behavior would be unchanged from now (unseeded random op calls are deterministic per traced function execution).

Does that work?

+1 to proper RNG management. This is extremely important to JAX users

We have that already, via the SeedGenerator class. But it does require that you seed your random op calls with a SeedGenerator. The general pattern is this:

class RandomLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1234)

    def call(self, x):
        return x + keras.random.normal(..., seed=self.seed_generator)

layer = RandomLayer()  # layer.non_trainable_variables includes the seed state
non_trainable_vars = layer.non_trainable_variables
outputs, non_trainable_vars = layer.stateless_call([], non_trainable_vars, x)  # returns the updated seed state ([] because this layer has no trainable variables)
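
To make that concrete, here is a rough end-to-end sketch under the JAX backend (NoiseLayer is a hypothetical, fully specified variant of the RandomLayer above, with the elided arguments filled in):

import jax
import jax.numpy as jnp
import keras

class NoiseLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1234)

    def call(self, x):
        return x + keras.random.normal(keras.ops.shape(x), seed=self.seed_generator)

layer = NoiseLayer()
x = jnp.ones((2, 4))
_ = layer(x)  # build the layer eagerly once

@jax.jit
def step(non_trainable_vars, x):
    # the SeedGenerator state rides along in non_trainable_vars and comes back updated
    out, non_trainable_vars = layer.stateless_call([], non_trainable_vars, x)
    return out, non_trainable_vars

ntv = [v.value for v in layer.non_trainable_variables]
out1, ntv = step(ntv, x)
out2, ntv = step(ntv, x)  # different noise each step, even though the function was traced only once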
martin-gorner commented 11 months ago

The default behavior would be unchanged from now

If this results in the Dropout layer silently not working (same dropout mask at each invocation), then this is not a good solution. You suggested catching this with an error in the compiled case. That sounds right. How doable is it?

martin-gorner commented 11 months ago

Wouldn't it be better to handle the rng at the model level, as part of the model state?

The standard training loop with dropout in JAX usually looks like this (omitting non-trainable variables for simplicity, but not RNGs):

# rng_key is an instance of jax.random.PRNGKey(seed)
# trainable_vars must be the first param as jax.grad differentiates against the first param only
def stateless_loss(trainable_vars, rng_key, x, y):
  y_pred = model.apply(trainable_vars, rng_key, x) # rng_key necessary in .apply or error
  loss_val = loss_fn(y, y_pred)
  return loss_val

stateless_grads = jax.value_and_grad(stateless_loss)  # value_and_grad so the loss is returned alongside the grads

state.rng_key = jax.random.PRNGKey(0)
for x, y in dataset:
  # training step, could be a jitted function
  loss_val, grads = stateless_grads(state.trainable_vars, state.rng_key, x, y)
  updates, state.optimizer_state = optimizer.update(grads, state.optimizer_state)
  state.trainable_vars = optax.apply_updates(state.trainable_vars, updates)

  # new rng key so that dropout gives a different value on next iteration
  state.rng_key, _ = jax.random.split(state.rng_key)
martin-gorner commented 11 months ago

Could Keras have something very similar? Like this:


model.build(data, seed) # internally creates RNG keys for layers that
                        # need them (mostly dropout), stores them in model.rng_keys
state = (model.trainable_variables, model.non_trainable_variables, model.rng_keys, optimizer.variables)

# ignoring model.non_trainable_variables again for simplicity

def stateless_loss(trainable_vars, rng_keys, x, y):
  y_pred = model.stateless_call(trainable_vars, rng_keys, x) # rng_keys necessary in call or error
  loss_val = loss_fn(y, y_pred)
  return loss_val

stateless_grads = jax.value_and_grad(stateless_loss)

state.rng_keys = model.rng_keys  # created by model.build(data, seed)
for x, y in dataset:
  # training step, could be a jitted function
  loss_val, grads = stateless_grads(state.trainable_vars, state.rng_keys, x, y)
  updates, state.optimizer_state = optimizer.update(grads, state.optimizer_state)
  state.trainable_vars = optimizer.stateless_apply(state.trainable_vars, updates)

  # new rng keys so that dropout gives a different value on next iteration
  state.rng_keys[0], _ = jax.random.split(state.rng_keys[0])
  state.rng_keys[1], _ = jax.random.split(state.rng_keys[1])
  # etc, there could be a convenience function for this: model.advance_all_rng_keys()
martin-gorner commented 11 months ago

Compared to your proposal above, I'd like to build my dropout-like layers with SeedGenerator() without params and still be able to assign a seed to all these layers at once through the Model abstraction.

Same thing for weight initializers, btw.

fchollet commented 11 months ago

That's roughly how it works, except it's actually much simpler and more intuitive.

AakashKumarNain commented 11 months ago

You don't need to manually update your RNG state like this; it gets updated inside the model automatically. You just need to do outputs, non_trainable_variables = model.stateless_call(trainable_variables, non_trainable_variables, inputs) in a loop.

This sounds like a clean solution to me. The only thing that I would expect in this scenario is a way to access the random state at any step, to ensure that the rng is being handled properly. That way I can validate that everything on the model side is working as expected.

martin-gorner commented 11 months ago

OK, that sounds good. Three questions:

1) Looking at the current implementation, is this how you should set seeds for deterministic training?

2) What happens for people using just keras.layers.Dense()? Will this result in a different initialization at each instantiation? Is there a way to control the seed of all the initializers in a model at once without using kernel_initializer= and bias_initializer= in every layer explicitly? Maybe through keras.random.set_global_rng_state(seed)?

3) Looking at the implementation of SeedGenerator.next(), I see this line: self.state.assign((seed_state + 1) * 5387 % 933199). Shouldn't this be something involving the platform-specific random split APIs like tf.random.split or jax.random.split? According to the JAX docstrings, the theory behind the "split" mechanism is in this article. I have not read it yet.

AakashKumarNain commented 11 months ago

Also, jax.random.split is purely deterministic; I am not sure about tf.random.split.
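
A quick check of that claim (same input key, same subkeys every time):

import jax

k1 = jax.random.split(jax.random.PRNGKey(0))
k2 = jax.random.split(jax.random.PRNGKey(0))
assert (k1 == k2).all()  # splitting is a pure function of the input key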

fchollet commented 11 months ago

for Dropout: keras.layers.Dropout(seed=123). SeedGenerator is built in so that each invocation gives a different dropout mask.

Yes

for a layer with weights it is keras.layers.Dense(kernel_initializer=keras.initializers.RandomNormal(seed=123)). SeedGenerator is not built in, but that seems to be what users want: the same random weights at each initialization.

Initializers are only meant to be called once, and integer-seeded initializers always return the same value, just like integer-seeded ops.

What happens for people using just keras.layers.Dense()? Will this result in a different initialization at each instantiation?

Yes

Is there a way to control the seed of all the initializers in a model at once without using kernel_initializer= and bias_initializer= in every layer explicitly?

Call keras.utils.set_random_seed(1337) at the start of your program. This provides full determinism, minus backend op-level indeterminism (certain GPU kernels), which is handled differently by each framework (e.g. in TF, tf.config.experimental.enable_op_determinism()).
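
For example, a minimal fully-seeded setup might look like this (a sketch; the second line only applies with the TensorFlow backend):

import keras
import tensorflow as tf

keras.utils.set_random_seed(1337)               # seeds Python, NumPy and the backend RNG
tf.config.experimental.enable_op_determinism()  # TF backend: deterministic GPU kernels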

Shouldn't this be something involving the platform-specific random split APIs

No need. But if you do want to manage your random seed sequence yourself via whatever algorithm of your choice, you have that option (you can just subclass SeedGenerator and do your own thing, then pass your custom class instance around to your layers).

martin-gorner commented 11 months ago

OK, thank you for your answers. For dropout-style random layers as well as weight initializers, this looks good. I especially like that the same setup (i.e. seed=123) is the correct one in both cases.

As for RNG splitting: I think the problem to solve is rng determinism, even in a distributed setting where execution order is not fully deterministic. The theory of why it is needed seems complicated (ref) and math-intensive. I don't have the full background, so I will accept your conclusion that this mechanism is "not needed" at face value.

I fear however that most users will prefer relying on the standard "split" mechanism implemented in TF and JAX rather than investing the time to analyze and be convinced by the assertion that the mechanism is useless. If you have a proof of this assertion, please put it forward, but consider the difficulty of then communicating it to all users and convincing them.

AakashKumarNain commented 11 months ago

@fchollet I looked at the SeedGenerator again, and I have a few more doubts now. Apologies for so many questions but I think this is a very critical aspect to discuss.

I fear however that most users will prefer relying on the standard "split" mechanism implemented in TF and JAX rather than investing the time to analyze and be convinced by the assertion that the mechanism is useless.

I totally agree with @martin-gorner here. It's very hard to convince the users to validate another PRNG implementation in their daily workflow.

A simple thing to do would be to leverage the JAX PRNG implementation in the SeedGenerator. Depending on the backend type, we could cast the rng state accordingly and use it everywhere. This reduces the burden of testing another pseudo-random generator. What do you think?

fchollet commented 11 months ago

Sure, I'm open to having a split_seed backend op of some kind. Then we can use it in SeedGenerator.next().

AakashKumarNain commented 11 months ago

Thank you. I guess we can refactor our seed generator like this:

class SeedGenerator:
    def __init__(self, seed=None, **kwargs):
        if seed is None:
            seed = jax.random.PRNGKey(make_default_seed())
        self._initial_seed = seed

        def seed_initializer(*args, **kwargs):
            return self.backend.convert_to_tensor(np.asarray(seed), dtype="uint32")

        self.state = self.backend.Variable(
            seed_initializer,
            shape=(2,),
            dtype="uint32",
            trainable=False,
            name="seed_generator_state",
        )

    def split_seed(self, seed_state):
        return jax.random.split(seed_state)

    def next(self):
        seed_state = jnp.array(backend.convert_to_numpy(self.state), dtype="uint32")
        seed_state, seed_sub_state = self.split_seed(seed_state)
        self.state.assign(seed_sub_state)
        return seed_sub_state

def draw_seed(seed):
    from keras_core.backend import convert_to_tensor

    if isinstance(seed, SeedGenerator):
        return seed.next()
    elif isinstance(seed, int):
        seed = jax.random.PRNGKey(seed)
        return SeedGenerator(seed=seed).next()
    elif seed is None:
        return global_seed_generator().next()
    raise ValueError(
        "Argument `seed` must be either an integer "
        "or an instance of `SeedGenerator`. "
        f"Received: seed={seed} (of type {type(seed)})"
    )
AakashKumarNain commented 11 months ago

Thoughts @fchollet @martin-gorner @GallagherCommaJack ?

AakashKumarNain commented 11 months ago

Good news! I just tested the suggestions I made above for the SeedGenerator class, and I am sure this works perfectly! 🕺 Let me know your thoughts. I will make a PR accordingly

Here are the results from a model I had:

First run: [screenshot of training logs]

Second run: [screenshot of training logs]

Btw these are the results with the JAX backend, model trained on 2 GPUs. The only thing that I am worried about is the lack of uint32 support in TF. I am not sure if casting back to int32 could create a problem somehow.

fchollet commented 11 months ago

Thank you. I guess we can refactor our seed generator like this:

That sounds good at a high level, but the split_seed method should be a backend function instead, with a different implementation in each backend. We should also not have any reference to PRNGKey outside of the JAX backend.

AakashKumarNain commented 11 months ago

@fchollet got it! I will make the changes accordingly, and will make a PR. Thanks for the pointers

AakashKumarNain commented 11 months ago

@fchollet I refactored the SeedGenerator class and made changes in the backend. For TF and JAX we have a very uniform implementation, but torch is a bit problematic. Why? Because torch changes the state of the generator implicitly as it consumes bits from the torch.Generator(...) instance. Here is the simplified version now:


class SeedGenerator:
    def __init__(self, seed=None, **kwargs):
        custom_backend = kwargs.pop("backend", None)
        if kwargs:
            raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
        if custom_backend is not None:
            self.backend = custom_backend
        else:
            self.backend = backend

        if seed is None:
            seed = self.backend.random.make_default_seed()
        else:
            seed = self.backend.random.make_initial_seed(seed)

        if backend.backend() == "tensorflow":
            seed_dtype = "int32"
        else:
            seed_dtype = "uint32"

        self._initial_seed = seed
        self.state = self.backend.Variable(
            self.backend.convert_to_tensor(seed),
            shape=tuple(seed.shape),
            dtype=seed_dtype,
            trainable=False,
            name="seed_generator_state",
        )

    def next(self):
        seed_state = self.backend.convert_to_tensor(self.state)
        seed_state, seed_sub_state = self.backend.random.get_next_state(seed_state)
        self.state.assign(seed_sub_state)
        return seed_sub_state

def global_seed_generator():
    gen = global_state.get_global_attribute("global_seed_generator")
    if gen is None:
        gen = SeedGenerator()
        global_state.set_global_attribute("global_seed_generator", gen)
    return gen

def global_rng_state():
    return global_seed_generator().state

def draw_seed(seed):
    from keras_core.backend import convert_to_tensor

    if isinstance(seed, SeedGenerator):
        return seed.next()
    elif isinstance(seed, int):
        return SeedGenerator(seed=seed).next()
    elif seed is None:
        return global_seed_generator().next()
    raise ValueError(
        "Argument `seed` must be either an integer "
        "or an instance of `SeedGenerator`. "
        f"Received: seed={seed} (of type {type(seed)})"
    )


One way to handle the differences in torch is to check the backend type in the SeedGenerator class and modify the returned values accordingly. For example, in TF and JAX we get seed_state, seed_sub_state when we call next, but in torch it would return the torch generator object along with the current state.
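
For reference, a minimal illustration of the torch behavior in question (the state advances in place as random bits are consumed, unlike JAX's explicit split-and-carry keys):

import torch

g = torch.Generator().manual_seed(1234)
state_before = g.get_state()
_ = torch.rand(3, generator=g)  # drawing numbers mutates the generator state implicitly
assert not torch.equal(state_before, g.get_state())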

Please let me know what you think. I can make a PR and we can discuss the modifications within the PR itself. That would be much easier for you to review and comment on.

AakashKumarNain commented 11 months ago

I have figured out a way to make everything work seamlessly. Will make a PR tomorrow

aaarrti commented 8 months ago

Hi @GallagherCommaJack @fchollet, do you plan to resume work on this issue? What is the recommended way of handling RNGs? I believe it would be much appreciated if you could add documentation (or a tutorial) covering this topic 🙃.