Sun-Xiaohui opened this issue 9 months ago
Hi - the `lstm` function accepts a number of arguments that must be static (e.g. array sizes, boolean flags, etc.). If you are wrapping it in `jit`, then you should declare these using `static_argnames` or `static_argnums`. It might look something like this:

```python
out = jax.jit(rnn.lstm, static_argnames=['input_size', 'hidden_size', 'num_layers', 'dropout', 'bidirectional', 'precision'])
```
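Calling the wrapped function then looks like an ordinary call, with the entries listed in `static_argnames` most naturally passed by keyword as plain Python values. A minimal sketch, with placeholder variable names:

```python
# Hypothetical arrays: placeholder names for whatever inputs the caller has prepared.
output, h_n, c_n = out(
    input_seq, h_0, c_0, weights, seq_lengths,
    input_size=28, hidden_size=128, num_layers=1,
    dropout=0.0, bidirectional=False)
```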
@jakevdp Thanks, that helps! However, I found that if I have a `Module` built on top of `rnn.lstm`, the forward values from `model.apply()` and `jax.jit(model.apply)()` do not match. Here is a minimal reproducible example:
```python
from flax.linen import Module
import jax.experimental.rnn as rnn
import jax.numpy as jnp
import jax


class LSTM(Module):
    input_size: int
    hidden_size: int
    num_layers: int = 1
    batch_first: bool = True
    bidirectional: bool = False

    def setup(self):
        self.w = self.param(
            'lstm_weights',
            rnn.init_lstm_weight,
            self.input_size,
            self.hidden_size,
            self.num_layers,
            self.bidirectional,
        )

    def __call__(self, input_seq):
        batch_size = input_seq.shape[0] if self.batch_first else input_seq.shape[1]
        num_direction = 2 if self.bidirectional else 1
        h_0 = jnp.zeros(
            shape=(self.num_layers * num_direction, batch_size, self.hidden_size),
            dtype=input_seq.dtype)
        c_0 = jnp.zeros(
            shape=(self.num_layers * num_direction, batch_size, self.hidden_size),
            dtype=input_seq.dtype)
        if not self.batch_first:
            input_seq = jnp.moveaxis(input_seq, 0, 1)
        seq_lengths = jnp.full((batch_size,), input_seq.shape[1], dtype=jnp.int32)
        output, h_n, c_n = rnn.lstm(
            input_seq,
            h_0,
            c_0,
            self.w,
            seq_lengths,
            self.input_size,
            self.hidden_size,
            self.num_layers,
            False,
            self.bidirectional,
        )
        if not self.batch_first:
            output = jnp.moveaxis(output, 0, 1)
        return output, (h_n, c_n)


def main():
    sequence_length = 28
    batch_size = 512
    input_size = 28
    hidden_size = 128
    model = LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
    variables = model.init(
        jax.random.PRNGKey(1),
        jnp.ones([sequence_length, batch_size, input_size], dtype=jnp.float32),
    )
    input = jax.random.uniform(
        key=jax.random.PRNGKey(1),
        shape=(sequence_length, batch_size, input_size),
        dtype=jnp.float32,
        minval=-0.8,
        maxval=0.8,
    )
    output1, _ = model.apply(variables, input)
    output2, _ = jax.jit(model.apply)(variables, input)
    print(jnp.array_equal(output1, output2))


if __name__ == '__main__':
    main()
```
The result is `False` (jax version: 0.4.8, jaxlib: 0.4.7+cuda12.cudnn88).

Is this a bug? Thanks. (By the way, why do I not need to declare `static_argnums` in this case?)
Thanks for the clear reproduction! That looks like a bug. In general we would not expect the results to be bitwise equivalent (`jit` rearranges floating-point operations, so floating-point errors accumulate differently), but the results here differ by far more than would be expected from floating-point error alone.

> By the way, why do I not need to declare `static_argnums` in this case?

Arguments only need to be marked static at the `jit` boundary. Here `model.apply` doesn't take any static arguments, so there's no need to specify `static_argnums`.
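Given that, a tolerance-based comparison is the right kind of check here; a minimal sketch, with illustrative tolerances rather than anything prescribed in this thread:

```python
import jax.numpy as jnp

# Bitwise equality (jnp.array_equal) is too strict under jit, which may reorder
# floating-point operations. A tolerance-based check distinguishes ordinary
# rounding noise from a genuine numerical mismatch like the one reported above.
def outputs_match(a, b, rtol=1e-5, atol=1e-6):
    return bool(jnp.allclose(a, b, rtol=rtol, atol=atol))

# e.g. outputs_match(output1, output2)
```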
I'll try to finish google/flax#2971 to provide a more streamlined experience.
Hello, JAX team and Flax team, is fixing this bug part of your current work plan?
@jakevdp

> Thanks for the clear reproduction! That looks like a bug. In general we would not expect the results to be bitwise equivalent (`jit` rearranges floating-point operations, so floating-point errors accumulate differently), but the results here differ by far more than would be expected from floating-point error alone.
>
> > By the way, why do I not need to declare `static_argnums` in this case?
>
> Arguments only need to be marked static at the `jit` boundary. Here `model.apply` doesn't take any static arguments, so there's no need to specify `static_argnums`.
I rewrote the code and tested it fully: `lstm` itself has no problem. The issue may be caused by the computation graph generated by XLA. How should I continue to investigate? Which part of the XLA source code should I review? (See also the inspection sketch after the code below.)
```python
import logging

from flax.linen import Module
import jax.experimental.rnn as rnn
import jax.numpy as jnp
import jax
from jax import random
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir


@jax.default_matmul_precision("float32")
def main():
    sequence_length = 28
    batch_size = 512
    input_size = 28
    hidden_size = 128
    batch_first = False
    num_layers = 1
    bidirectional = False
    input_seq = jax.random.uniform(
        key=jax.random.PRNGKey(1),
        shape=(sequence_length, batch_size, input_size),
        dtype=jnp.float32,
        minval=-0.8,
        maxval=0.8,
    )
    batch_size = input_seq.shape[0] if batch_first else input_seq.shape[1]
    num_direction = 2 if bidirectional else 1
    h_0 = jnp.zeros(
        shape=(num_layers * num_direction, batch_size, hidden_size),
        dtype=input_seq.dtype)
    c_0 = jnp.zeros(
        shape=(num_layers * num_direction, batch_size, hidden_size),
        dtype=input_seq.dtype)
    if not batch_first:
        input_seq = jnp.moveaxis(input_seq, 0, 1)
    seq_lengths = jnp.full((batch_size,), input_seq.shape[1], dtype=jnp.int32)
    key = random.PRNGKey(0)
    w = random.uniform(key, (80896,), minval=0, maxval=1)
    output, h_n, c_n = rnn.lstm(
        input_seq,
        h_0,
        c_0,
        w,
        seq_lengths,
        input_size,
        hidden_size,
        num_layers,
        False,
        bidirectional,
    )
    jit_lstm = jax.jit(
        rnn.lstm,
        static_argnames=['input_size', 'hidden_size', 'num_layers',
                         'dropout', 'bidirectional', 'precision'])
    jit_output, jit_h_n, jit_c_n = jit_lstm(
        input_seq,
        h_0,
        c_0,
        w,
        seq_lengths,
        input_size,
        hidden_size,
        num_layers,
        False,
        bidirectional,
    )
    print(jnp.allclose(output, jit_output))
    print(jnp.allclose(h_n, jit_h_n))
    print(jnp.allclose(c_n, jit_c_n))


if __name__ == '__main__':
    main()
```
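As a way to continue investigating without diving straight into the XLA sources, one could dump and compare the lowered and compiled programs for the two paths; a minimal sketch, assuming the `model`, `variables`, and `input` objects from the earlier reproduction:

```python
# Inspect what jit actually lowers and compiles for model.apply.
lowered = jax.jit(model.apply).lower(variables, input)
print(lowered.as_text())            # StableHLO; look for stablehlo.transpose feeding the cudnn RNN call
print(lowered.compile().as_text())  # post-optimization HLO as compiled by XLA
```

Setting `XLA_FLAGS=--xla_dump_to=/tmp/xla_dump` before running also makes XLA write out the HLO modules it compiles, which can be diffed between the eager and jitted runs.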
@jakevdp It seems that the error occurs when XLA processes the `jnp.moveaxis` operator. I changed the code to the following (commenting out the `moveaxis` calls), and the result is correct:
```python
from flax.linen import Module
import jax.experimental.rnn as rnn
import jax.numpy as jnp
import jax


class LSTM(Module):
    input_size: int
    hidden_size: int
    num_layers: int = 1
    batch_first: bool = True
    bidirectional: bool = False

    def setup(self):
        self.w = self.param(
            'lstm_weights',
            rnn.init_lstm_weight,
            self.input_size,
            self.hidden_size,
            self.num_layers,
            self.bidirectional,
        )

    def __call__(self, input_seq):
        batch_size = input_seq.shape[0] if self.batch_first else input_seq.shape[1]
        num_direction = 2 if self.bidirectional else 1
        h_0 = jnp.zeros(
            shape=(self.num_layers * num_direction, batch_size, self.hidden_size),
            dtype=input_seq.dtype)
        c_0 = jnp.zeros(
            shape=(self.num_layers * num_direction, batch_size, self.hidden_size),
            dtype=input_seq.dtype)
        # if not self.batch_first:
        #     input_seq = jnp.moveaxis(input_seq, 0, 1)
        seq_lengths = jnp.full((batch_size,), input_seq.shape[1], dtype=jnp.int32)
        output, h_n, c_n = rnn.lstm(
            input_seq,
            h_0,
            c_0,
            self.w,
            seq_lengths,
            self.input_size,
            self.hidden_size,
            self.num_layers,
            False,
            self.bidirectional,
        )
        # if not self.batch_first:
        #     output = jnp.moveaxis(output, 0, 1)
        return output, (h_n, c_n)


def main():
    sequence_length = 28
    batch_size = 512
    input_size = 28
    hidden_size = 128
    model = LSTM(input_size=input_size,
                 hidden_size=hidden_size, batch_first=False)
    variables = model.init(
        jax.random.PRNGKey(1),
        jnp.ones([sequence_length, batch_size, input_size], dtype=jnp.float32),
    )
    input = jax.random.uniform(
        key=jax.random.PRNGKey(1),
        shape=(sequence_length, batch_size, input_size),
        dtype=jnp.float32,
        minval=-0.8,
        maxval=0.8,
    )
    output1, _ = model.apply(variables, input)
    output2, _ = jax.jit(model.apply)(variables, input)
    print(jnp.array_equal(output1, output2))


if __name__ == '__main__':
    main()
```
```
%13 = stablehlo.transpose %8, dims = [1, 0, 2] : (tensor<512x28x128xf32>) -> tensor<28x512x128xf32>
```

It is highly likely that `jnp.moveaxis` and `stablehlo.transpose` are not equivalent.
> It is highly likely that `jnp.moveaxis` and `stablehlo.transpose` are not equivalent.

`moveaxis` is just another way of writing `transpose`: https://github.com/google/jax/blob/e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b/jax/_src/numpy/lax_numpy.py#L1371-L1381
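If there is any doubt, the equivalence can be checked directly, outside the RNN entirely; a minimal sketch using the same array shape as the transpose shown above:

```python
import jax
import jax.numpy as jnp

# On a 3-D array, moveaxis(x, 0, 1) is by definition transpose with axes (1, 0, 2).
x = jax.random.normal(jax.random.PRNGKey(0), (512, 28, 128), dtype=jnp.float32)
a = jnp.moveaxis(x, 0, 1)
b = jnp.transpose(x, (1, 0, 2))
print(jnp.array_equal(a, b))                                             # expected: True
print(jnp.array_equal(a, jax.jit(lambda t: jnp.moveaxis(t, 0, 1))(x)))   # expected: True
```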
The above code reports an error:

```
jax._src.errors.UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. Tracers cannot be passed as non-differentiable arguments to custom_vjp functions; instead, nondiff_argnums should only be used for arguments that can't be or contain JAX tracers, e.g. function-valued arguments. In particular, array-valued arguments should typically not be indicated as nondiff_argnums.
```

How can `rnn.lstm()` be jitted? Thanks a lot!

JAX version: 0.4.8, jaxlib: 0.4.7+cuda12.cudnn88
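For what it's worth, the error above comes from the integer/boolean configuration arguments being traced. Either of the patterns already shown in this thread avoids it: declare them in `static_argnames` as in the earlier examples, or close over them so the jitted function only receives array arguments. A minimal sketch of the latter, with placeholder sizes:

```python
import jax
import jax.experimental.rnn as rnn

# Placeholder configuration (matching the sizes used earlier in the thread).
INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, BIDIRECTIONAL = 28, 128, 1, False

@jax.jit
def run_lstm(x, h_0, c_0, w, seq_lengths):
    # The size/flag arguments are closed-over Python constants, so they are
    # static by construction and never become tracers.
    return rnn.lstm(x, h_0, c_0, w, seq_lengths,
                    INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS,
                    0.0, BIDIRECTIONAL)
```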