google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.experimental.rnn.lstm() violates JIT invariance #19177

Status: Open. Sun-Xiaohui opened this issue 6 months ago.

Sun-Xiaohui commented 6 months ago
import jax
import jax.numpy as jnp
from jax.experimental import rnn

batch_size = 8
input_size = 8
hidden_size = 4
num_layers = 1
bidirectional = True
num_directions = 2 if bidirectional else 1

seq_lengths = jnp.array([4, 5, 4, 1, 1, 1, 1, 1], dtype=jnp.int32)
max_seq_length = 5

root_key = jax.random.PRNGKey(1)
k1, k2, k3, k4 = jax.random.split(root_key, 4)
x = jax.random.normal(
    k1, (batch_size, max_seq_length, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
    k2, (num_directions * num_layers, batch_size, hidden_size),
    dtype=jnp.float32)
c_0 = jax.random.normal(
    k3, (num_directions * num_layers, batch_size, hidden_size),
    dtype=jnp.float32)
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                               bidirectional)

out = jax.jit(rnn.lstm)
y, h, c = out(
    x,
    h_0,
    c_0,
    weights,
    seq_lengths=seq_lengths,
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=False,
    bidirectional=bidirectional)

The code above raises an error: jax._src.errors.UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. Tracers cannot be passed as non-differentiable arguments to custom_vjp functions; instead, nondiff_argnums should only be used for arguments that can't be or contain JAX tracers, e.g. function-valued arguments. In particular, array-valued arguments should typically not be indicated as nondiff_argnums.

How can rnn.lstm() be jitted? Thanks a lot!

JAX version: 0.4.8, jaxlib: 0.4.7+cuda12.cudnn88

jakevdp commented 6 months ago

Hi, the lstm function accepts several arguments that must be static (e.g. array sizes and boolean flags). If you wrap it in jit, declare these using static_argnames or static_argnums. It might look something like this:

out = jax.jit(rnn.lstm, static_argnames=['input_size', 'hidden_size', 'num_layers', 'dropout', 'bidirectional', 'precision'])
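The UnexpectedTracerError above comes from custom_vjp's check that arguments listed in nondiff_argnums are not tracers; presumably rnn.lstm receives its size and flag arguments in such positions. A minimal sketch of the same pattern on a hypothetical toy function `scaled_tanh` (not part of JAX), whose `scale` argument is non-differentiable: marking it static keeps it a plain Python value inside jit, so the check passes.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical toy function: `scale` is declared non-differentiable,
# so custom_vjp requires it to be a concrete (non-traced) value.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_tanh(x, scale):
    return jnp.tanh(x) * scale


def scaled_tanh_fwd(x, scale):
    y = jnp.tanh(x)
    return y * scale, y  # save tanh(x) as the residual


def scaled_tanh_bwd(scale, y, g):
    # Non-differentiable arguments are passed first to the backward rule.
    return ((1.0 - y**2) * scale * g,)


scaled_tanh.defvjp(scaled_tanh_fwd, scaled_tanh_bwd)

# Marking `scale` static keeps it a Python float inside jit.
jitted = jax.jit(scaled_tanh, static_argnums=1)

x = jnp.linspace(-1.0, 1.0, 5)
print(jitted(x, 2.0))
print(jax.grad(lambda v: jitted(v, 2.0).sum())(x))
```

Without `static_argnums`, `scale` would be traced by jit and reach the nondiff position as a tracer, which is the situation the error message describes.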
Sun-Xiaohui commented 6 months ago

@jakevdp Thanks, that helps! However, I found that for a Flax Module built on rnn.lstm, the forward outputs of model.apply() and jax.jit(model.apply)() do not match. Here is a minimal reproducible example:

from flax.linen import Module
import jax.experimental.rnn as rnn
import jax.numpy as jnp
import jax

class LSTM(Module):

  input_size: int
  hidden_size: int
  num_layers: int = 1
  batch_first: bool = True
  bidirectional: bool = False

  def setup(self):
    self.w = self.param(
      'lstm_weights',
      rnn.init_lstm_weight,
      self.input_size,
      self.hidden_size,
      self.num_layers, self.bidirectional
    )

  def __call__(self, input_seq):
    batch_size = input_seq.shape[0] if self.batch_first else input_seq.shape[1]
    num_direction = 2 if self.bidirectional else 1
    h_0 = jnp.zeros(shape=(self.num_layers * num_direction, batch_size, self.hidden_size), dtype=input_seq.dtype)
    c_0 = jnp.zeros(shape=(self.num_layers * num_direction, batch_size, self.hidden_size), dtype=input_seq.dtype)
    if not self.batch_first:
      input_seq = jnp.moveaxis(input_seq, 0, 1)
    seq_lengths = jnp.full((batch_size,), input_seq.shape[1], dtype=jnp.int32)
    output, h_n, c_n = rnn.lstm(
      input_seq,
      h_0,
      c_0,
      self.w,
      seq_lengths,
      self.input_size,
      self.hidden_size,
      self.num_layers,
      False,
      self.bidirectional
    )
    if not self.batch_first:
      output = jnp.moveaxis(output, 0, 1)
    return output, (h_n, c_n)

def main():
  sequence_length = 28
  batch_size = 512
  input_size = 28
  hidden_size = 128

  model = LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
  variables = model.init(
    jax.random.PRNGKey(1),
    jnp.ones([sequence_length, batch_size, input_size], dtype=jnp.float32)
  )
  input = jax.random.uniform(
    key=jax.random.PRNGKey(1),
    shape=(sequence_length, batch_size, input_size),
    dtype=jnp.float32,
    minval=-0.8,
    maxval=0.8
  )
  output1, _ = model.apply(variables, input)
  output2, _ = jax.jit(model.apply)(variables, input)
  print(jnp.array_equal(output1, output2))

if __name__ == '__main__':
  main()

Result: False (JAX version: 0.4.8, jaxlib: 0.4.7+cuda12.cudnn88)

Is this a bug? Thanks. (By the way, why don't I need to declare static_argnums in this case?)

jakevdp commented 5 months ago

Thanks for the clear reproduction! That looks like a bug. In general we would not expect the results to be bitwise equivalent (jit rearranges floating point operations, so floating point errors accumulate differently) but the results here differ by far more than would be expected for just floating point error.
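Since jit may reorder floating-point operations, comparing eager and jitted results is usually done with a tolerance rather than jnp.array_equal, while also reporting the largest deviation so that a real discrepancy stands out. A minimal sketch, assuming a hypothetical stand-in computation (a matmul followed by tanh) in place of the LSTM; the helper name `report_diff` and the tolerances are assumptions to tune per model:

```python
import jax
import jax.numpy as jnp


def report_diff(a, b, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: summarize how far two arrays are apart."""
    diff = jnp.abs(a - b)
    return {
        "bitwise_equal": bool(jnp.array_equal(a, b)),
        "allclose": bool(jnp.allclose(a, b, rtol=rtol, atol=atol)),
        "max_abs_diff": float(diff.max()),
    }


def step(x, w):
    # Stand-in computation: a matmul followed by a nonlinearity.
    return jnp.tanh(x @ w)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(kx, (64, 28), minval=-0.8, maxval=0.8)
w = jax.random.normal(kw, (28, 128))

eager_out = step(x, w)
jitted_out = jax.jit(step)(x, w)
print(report_diff(eager_out, jitted_out))
```

A `max_abs_diff` near float32 rounding error points to reordering; a much larger value, as reported in this issue, points to an actual bug.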

By the way, why in this case, I donot need to declare static_argnums?

Yes: arguments only need to be marked static at the jit boundary. Here model.apply doesn't take any static arguments, so there's no need to specify static_argnums.
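To illustrate the jit boundary: values that a jitted function reads from Python attributes or closures are never traced, so only the arguments of the function handed to jax.jit need static_argnames. A sketch with a hypothetical helper `make_initial_state` and a hypothetical `Config` class standing in for the Flax module's attributes:

```python
import jax
import jax.numpy as jnp


def make_initial_state(x, hidden_size, num_directions):
    # hidden_size and num_directions set an array shape, so they must be
    # concrete Python ints, not tracers.
    return jnp.zeros((num_directions, x.shape[0], hidden_size), x.dtype)


class Config:
    """Hypothetical stand-in for Flax module attributes."""
    hidden_size = 4
    bidirectional = True


def apply(x):
    cfg = Config()
    num_directions = 2 if cfg.bidirectional else 1
    # cfg values are not arguments of `apply`, so jit never traces them.
    return make_initial_state(x, cfg.hidden_size, num_directions)


x = jnp.ones((8, 5))
print(jax.jit(apply)(x).shape)  # (2, 8, 4)

# Jitting the helper directly turns hidden_size/num_directions into jit
# arguments, so they must be declared static there.
helper = jax.jit(make_initial_state, static_argnames=("hidden_size", "num_directions"))
print(helper(x, hidden_size=4, num_directions=2).shape)  # (2, 8, 4)
```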

cgarciae commented 5 months ago

I'll try to finish google/flax#2971 to provide a more streamlined experience.

Sun-Xiaohui commented 3 months ago

Hello JAX and Flax teams, is fixing this bug part of your current work plan?

knightXun commented 1 month ago

@jakevdp

I rewrote this code and tested it thoroughly: lstm itself has no problem. The issue may be caused by the computation graph that XLA generates. How should I continue investigating? Which part of the XLA source code should I review?

import jax
import jax.experimental.rnn as rnn
import jax.numpy as jnp
from jax import random

@jax.default_matmul_precision("float32")
def main():
    sequence_length = 28
    batch_size = 512
    input_size = 28
    hidden_size = 128
    batch_first = False
    num_layers = 1
    bidirectional = False

    input_seq = jax.random.uniform(
        key=jax.random.PRNGKey(1),
        shape=(sequence_length, batch_size, input_size),
        dtype=jnp.float32,
        minval=-0.8,
        maxval=0.8
    )

    batch_size = input_seq.shape[0] if batch_first else input_seq.shape[1]
    num_direction = 2 if bidirectional else 1

    h_0 = jnp.zeros(shape=(num_layers * num_direction,
                           batch_size, hidden_size), dtype=input_seq.dtype)

    c_0 = jnp.zeros(shape=(num_layers * num_direction,
                           batch_size, hidden_size), dtype=input_seq.dtype)

    if not batch_first:
        input_seq = jnp.moveaxis(input_seq, 0, 1)

    seq_lengths = jnp.full(
        (batch_size,), input_seq.shape[1], dtype=jnp.int32)

    key = random.PRNGKey(0)
    w = random.uniform(key, (80896,), minval=0, maxval=1)

    output, h_n, c_n = rnn.lstm(
        input_seq,
        h_0,
        c_0,
        w,
        seq_lengths,
        input_size,
        hidden_size,
        num_layers,
        False,
        bidirectional
    )

    jit_output, jit_h_n, jit_c_n = jax.jit(rnn.lstm, static_argnames=['input_size', 'hidden_size', 'num_layers', 'dropout', 'bidirectional', 'precision'])(
        input_seq,
        h_0,
        c_0,
        w,
        seq_lengths,
        input_size,
        hidden_size,
        num_layers,
        False,
        bidirectional
    )

    print(jnp.allclose(output, jit_output))
    print(jnp.allclose(h_n, jit_h_n))
    print(jnp.allclose(c_n, jit_c_n))

if __name__ == '__main__':
    main()
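The hard-coded 80896 in the script above matches the length of the flat weight vector for a single-layer, unidirectional LSTM with input_size=28 and hidden_size=128. A sketch of that arithmetic, assuming a cuDNN-style layout (per layer and direction: an input matrix of shape (4*hidden_size, layer_input), a recurrent matrix of shape (4*hidden_size, hidden_size), and two bias vectors of length 4*hidden_size); the helper `lstm_flat_weight_count` is hypothetical, not a JAX function:

```python
def lstm_flat_weight_count(input_size, hidden_size, num_layers, bidirectional):
    """Hypothetical helper: flat LSTM weight count under a cuDNN-style layout."""
    num_directions = 2 if bidirectional else 1
    gates = 4 * hidden_size  # input, forget, cell, and output gates
    total = 0
    for layer in range(num_layers):
        # Layers after the first consume the concatenated outputs of all directions.
        layer_input = input_size if layer == 0 else hidden_size * num_directions
        per_direction = gates * layer_input + gates * hidden_size + 2 * gates
        total += num_directions * per_direction
    return total


print(lstm_flat_weight_count(28, 128, 1, False))  # 80896
print(lstm_flat_weight_count(8, 4, 1, True))      # 448
```

Computing the size this way, rather than hard-coding it, keeps a random weight vector consistent if input_size or hidden_size changes.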
knightXun commented 1 month ago

@jakevdp It seems the error occurs when XLA processes the jnp.moveaxis operation. I changed the code as follows (commenting out both jnp.moveaxis calls), and the result is now correct:

from flax.linen import Module
import jax.experimental.rnn as rnn
import jax.numpy as jnp
import jax

class LSTM(Module):

    input_size: int
    hidden_size: int
    num_layers: int = 1
    batch_first: bool = True
    bidirectional: bool = False

    def setup(self):
        self.w = self.param(
            'lstm_weights',
            rnn.init_lstm_weight,
            self.input_size,
            self.hidden_size,
            self.num_layers, self.bidirectional
        )

    def __call__(self, input_seq):
        batch_size = input_seq.shape[0] if self.batch_first else input_seq.shape[1]
        num_direction = 2 if self.bidirectional else 1
        h_0 = jnp.zeros(shape=(self.num_layers * num_direction,
                        batch_size, self.hidden_size), dtype=input_seq.dtype)
        c_0 = jnp.zeros(shape=(self.num_layers * num_direction,
                        batch_size, self.hidden_size), dtype=input_seq.dtype)
        # if not self.batch_first:
        #     input_seq = jnp.moveaxis(input_seq, 0, 1)
        seq_lengths = jnp.full(
            (batch_size,), input_seq.shape[1], dtype=jnp.int32)
        output, h_n, c_n = rnn.lstm(
            input_seq,
            h_0,
            c_0,
            self.w,
            seq_lengths,
            self.input_size,
            self.hidden_size,
            self.num_layers,
            False,
            self.bidirectional
        )
        # if not self.batch_first:
        #     output = jnp.moveaxis(output, 0, 1)
        return output, (h_n, c_n)

def main():
    sequence_length = 28
    batch_size = 512
    input_size = 28
    hidden_size = 128

    model = LSTM(input_size=input_size,
                 hidden_size=hidden_size, batch_first=False)
    variables = model.init(
        jax.random.PRNGKey(1),
        jnp.ones([sequence_length, batch_size, input_size], dtype=jnp.float32)
    )
    input = jax.random.uniform(
        key=jax.random.PRNGKey(1),
        shape=(sequence_length, batch_size, input_size),
        dtype=jnp.float32,
        minval=-0.8,
        maxval=0.8
    )
    output1, _ = model.apply(variables, input)
    output2, _ = jax.jit(model.apply)(variables, input)
    print(jnp.array_equal(output1, output2))

if __name__ == '__main__':
    main()
knightXun commented 1 month ago
   %13 = stablehlo.transpose %8, dims = [1, 0, 2] : (tensor<512x28x128xf32>) -> tensor<28x512x128xf32>

It is highly likely that jnp.moveaxis and stablehlo.transpose are not equivalent.
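One way to check what jnp.moveaxis becomes is to print the jaxpr and the lowered module for just that step. A sketch, using a hypothetical function `to_batch_major` that mirrors the moveaxis call in the module above; the exact printed text depends on the JAX version:

```python
import jax
import jax.numpy as jnp


def to_batch_major(x):
    # (seq_len, batch, features) -> (batch, seq_len, features)
    return jnp.moveaxis(x, 0, 1)


x = jnp.zeros((28, 512, 28), jnp.float32)

# jaxpr: the traced program before lowering
jaxpr = jax.make_jaxpr(to_batch_major)(x)
print(jaxpr)

# Lowered module text (StableHLO on recent JAX versions)
hlo_text = jax.jit(to_batch_major).lower(x).as_text()
print(hlo_text)
```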

jakevdp commented 1 month ago

moveaxis is just another way of writing transpose: https://github.com/google/jax/blob/e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b/jax/_src/numpy/lax_numpy.py#L1371-L1381
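A quick numerical check of that equivalence: a transpose only moves data, so jnp.moveaxis(x, 0, 1) and jnp.transpose(x, (1, 0, 2)) should agree bitwise, both eagerly and under jit. A sketch on random data, using the input shape from the stablehlo.transpose line quoted earlier in the thread:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (512, 28, 128))

moved = jnp.moveaxis(x, 0, 1)
transposed = jnp.transpose(x, (1, 0, 2))
moved_jit = jax.jit(lambda v: jnp.moveaxis(v, 0, 1))(x)

print(moved.shape)  # (28, 512, 128)
print(jnp.array_equal(moved, transposed), jnp.array_equal(moved, moved_jit))  # True True
```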