google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

AutoRegressive Decoding currently fails if input prompt > 1 #1317

Open · patrickvonplaten opened this issue 3 years ago

patrickvonplaten commented 3 years ago

Problem you have encountered:

I want to run a model based on flax.linen.SelfAttention in auto-regressive mode and pass an input prompt of length > 1. This, however, does not seem possible at the moment, e.g.:

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

attn_layer = SelfAttention(1, decode=True, use_bias=False)

batch_size = 1
max_decoder_length = 4
hidden_size = 2
prompt_length = 2   # setting this to 1 would work

# Initialize params and the autoregressive cache with a full-length dummy input.
init_variables = attn_layer.init(jax.random.PRNGKey(0), jnp.ones((batch_size, max_decoder_length, hidden_size)), deterministic=True)

params = init_variables["params"]
cache = init_variables["cache"]

# A prompt containing more than one position.
dummy_prompt = jnp.arange(batch_size * prompt_length * hidden_size).reshape((batch_size, prompt_length, hidden_size))

# Raises the ValueError shown below: decode mode expects a single token per call.
output, mutable_vars = attn_layer.apply({"params": params, "cache": cache}, dummy_prompt, mutable=["cache"], deterministic=True)

leads to an error. Also check this notebook.

What you expected to happen:

Instead, the code should work and the cache entries for the first prompt_length positions should be stored.
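
For reference, a quick way to see what the initialized cache collection contains (run right after the init call above; the exact entry names and shapes are a Flax-internal detail and may differ between versions):

print(jax.tree_util.tree_map(lambda x: x.shape, cache))
# e.g. key/value buffers shaped (batch_size, max_decoder_length, num_heads, head_dim)
# plus a scalar decode-position index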

Logs, error messages, etc:

~/python_bin/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273     _context.module_stack.append(self)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:
    277         filter_fn = _context.capture_stack[-1]

~/python_bin/flax/linen/attention.py in __call__(self, inputs_q, inputs_kv, mask, deterministic)
    265         expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
    266         if expected_shape != query.shape:
--> 267           raise ValueError('Autoregressive cache shape error, '
    268                            'expected query shape %s instead got %s.' %
    269                            (expected_shape, query.shape))

ValueError: Autoregressive cache shape error, expected query shape (1, 1, 1, 2) instead got (1, 2, 1, 2).

Steps to reproduce:

See the code/Colab above.

patrickvonplaten commented 3 years ago

This issue could be solved with this PR: https://github.com/google/flax/pull/1316

marcvanzee commented 3 years ago

I want to run a model based on flax.linen.SelfAttention in auto-regressive mode and pass an input prompt > 1.

Apologies if I am misunderstanding, but for fast decoding we deliberately only feed 1 token at a time and assume the caller iterates over the input and maintains the cache. So you could do something like this in your code:

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention
from jax import lax

attn_layer = SelfAttention(1, decode=True, use_bias=False)

batch_size = 1
max_decoder_length = 4
hidden_size = 2
prompt_length = 2
output_dim = 2

init_variables = attn_layer.init(jax.random.PRNGKey(0), jnp.ones((batch_size, max_decoder_length, hidden_size)), deterministic=True)

params = init_variables["params"]
cache = init_variables["cache"]

prompts = jnp.arange(batch_size * prompt_length * hidden_size).reshape((batch_size, prompt_length, hidden_size))
outputs = jnp.zeros(prompts.shape)  # This will be filled one token at a time

for i in range(prompt_length):
  prompt = lax.slice_in_dim(prompts, i, i+1, axis=1)
  output, mutable_vars = attn_layer.apply({"params": params, "cache": cache}, prompt, mutable=["cache"], deterministic=True)
  outputs = lax.dynamic_update_slice(outputs, output, (0, i, 0))
  cache = mutable_vars["cache"]  # Update cache for next iteration

print(outputs)

Are you saying that this approach doesn't work for your use case? If so, could you please explain why not?

patrickvonplaten commented 3 years ago

Okay, I see! Thanks a lot for the explicit example, @marcvanzee.

This would work for me, but wouldn't it be quite inefficient to do it this way?

E.g., if someone wants to continue generating text from an input prompt that has many tokens but only needs to generate a few new tokens, e.g. few-shot prompting (GPT-3 style):

prompt: "Translate to German. Hello today is a nice day. Hallo heute ist ein schoener Tag. Translate to German. How are you? Wie geht es dir? Translate to German. What time is it?"

Let's say the prompt has a token length of 100. In this case, we would do 100 forward passes just to initialize the cache, and only the 100th forward pass would generate the first meaningful output token -> the 101st token, no?

It should be feasible to just run a single forward pass over the 100 prompt tokens to initialize the cache, output the 101st token, and then proceed with auto-regressive decoding, no?
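
To make the cost concrete, here is the loop pattern from above scaled to a 100-token prompt (a minimal sketch with a toy layer; only the number of forward passes matters here):

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

attn_layer = SelfAttention(1, decode=True, use_bias=False)
batch_size, max_decoder_length, hidden_size = 1, 128, 2
prompt_length = 100  # e.g. a long few-shot prompt

init_variables = attn_layer.init(jax.random.PRNGKey(0), jnp.ones((batch_size, max_decoder_length, hidden_size)), deterministic=True)
params = init_variables["params"]
cache = init_variables["cache"]

prompts = jnp.ones((batch_size, prompt_length, hidden_size))

# 100 single-token passes are needed just to warm up the cache; only the last
# pass produces the prediction for token 101.
for i in range(prompt_length):
  token = jax.lax.slice_in_dim(prompts, i, i + 1, axis=1)
  output, mutable_vars = attn_layer.apply({"params": params, "cache": cache}, token, mutable=["cache"], deterministic=True)
  cache = mutable_vars["cache"]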

marcvanzee commented 3 years ago

Ah ok, interesting! I don't think we have done this kind of decoding yet, since in our use cases each token is always generated autoregressively, so you have to go token by token. However, if that is not a requirement, then initializing the cache in a single pass does seem more efficient.

@levskaya what do you think?

patrickvonplaten commented 3 years ago

For some context, I'm currently implementing this for FlaxGPT2 here.

I have a working solution where I split generation into 2 steps:

  1. Run the input_prompt to generate the next token and initialize the cache
  2. Run auto-regressive decoding with query_length always equal to 1

An example test is here.

To make the test work, I changed the self.is_decoder code path, as can be seen here.

levskaya commented 3 years ago

Hey! I commented on #1316 but hadn't seen these comments when writing that. It sounds like you want to do a single-pass "teacher-forced" cache initialization. We'd need to add a third "operating mode" to the attention layer (basically: run normally, but then stuff the keys and values for the first N tokens into the cache).
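
A rough sketch of that idea from user code, assuming the cache collection keeps the cached_key / cached_value / cache_index entries used internally by flax.linen.attention (a layout that may change between versions), and reusing the toy setup from the top of the issue:

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

# Same toy configuration as in the issue above.
attn_layer = SelfAttention(1, decode=True, use_bias=False)
batch_size, max_decoder_length, hidden_size, prompt_length = 1, 4, 2, 2

init_variables = attn_layer.init(jax.random.PRNGKey(0), jnp.ones((batch_size, max_decoder_length, hidden_size)), deterministic=True)
params = init_variables["params"]
cache = init_variables["cache"]
prompts = jnp.ones((batch_size, prompt_length, hidden_size))

# "Teacher-forced" prefill: project the first prompt_length - 1 tokens with the
# layer's own key/value kernels (use_bias=False above, so no bias terms) and
# write them directly into the cache buffers.
n = prompt_length - 1
k = jnp.einsum("bld,dhf->blhf", prompts[:, :n], params["key"]["kernel"])
v = jnp.einsum("bld,dhf->blhf", prompts[:, :n], params["value"]["kernel"])
cache = {
    "cached_key": cache["cached_key"].at[:, :n].set(k),
    "cached_value": cache["cached_value"].at[:, :n].set(v),
    "cache_index": jnp.array(n, dtype=jnp.int32),
}

# One ordinary single-token decode step on the last prompt token then returns
# the output for the first position after the prompt and leaves the cache
# ready for regular token-by-token decoding.
output, mutable_vars = attn_layer.apply({"params": params, "cache": cache}, prompts[:, n:], mutable=["cache"], deterministic=True)
cache = mutable_vars["cache"]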