google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0
8.07k stars 813 forks source link

LSHSelfAttention only works with different lengths if the parameter use_reference_code is set to "True" #1664

Open renevs opened 3 years ago

renevs commented 3 years ago

Description

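Once initialized, LSHSelfAttention only accepts inputs whose sequence length differs from the one used at initialization if the parameter use_reference_code is set to "True". Because BucketByLength produces batches of varying sequence length, I can't use LSHSelfAttention in a Reformer model together with BucketByLength.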

Environment information

OS: Ubuntu 18.04.1 LTS

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.0.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.16
jaxlib==0.1.67
jupyter-server-mathjax==0.2.2

$ python -V
Python 3.8.10

For bugs: reproduction and error logs

# Steps to reproduce:
import trax
from trax import layers as tl
import jax.numpy as jnp

# Initialize the layer for inputs of shape (batch=4, length=32, d_model=512).
init_signature = trax.shapes.ShapeDtype((4, 32, 512), dtype=jnp.int32)
layer = tl.LSHSelfAttention(
    chunk_len=8,
    use_reference_code=False,
    mode='train',
)
layer.init(init_signature)

# Call the already-initialized layer on inputs whose sequence length (16)
# differs from the one seen at init (32). Per this report, this fails with
# use_reference_code=False and only works with use_reference_code=True.
inputs = jnp.ones((8, 16, 512))
layer(inputs)

Error logs:

```
LayerError Traceback (most recent call last)

in 12 g.init(shapedtype) 13 valores = jnp.ones((8,16,512)) ---> 14 g(valores) ~/anaconda3/envs/ambiente_ipdr/lib/python3.8/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng) 195 self.state = state # Needed if the model wasn't fully initialized. 196 state = self.state --> 197 outputs, new_state = self.pure_fn(x, weights, state, rng) 198 self.state = new_state 199 return outputs ~/anaconda3/envs/ambiente_ipdr/lib/python3.8/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache) 603 # Skipping 3 lines as it's always the uninteresting internal call. 604 name, trace = self._name, _short_traceback(skip=3) --> 605 raise LayerError(name, 'pure_fn', 606 self._caller, signature(x), trace) from None 607 LayerError: Exception passing through layer LSHSelfAttention (in pure_fn): layer created in file [...]/layers/research/efficient_attention.py, line 1744 layer input shapes: ShapeDtype{shape:(8, 16, 512), dtype:float32} File [...]/trax/layers/base.py, line 673, in _do_custom_gradients output, state = do_forward(self.state, self._rng, x, self.weights) File [...]/jax/_src/custom_derivatives.py, line 486, in __call__ out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, File [...]/jax/_src/custom_derivatives.py, line 566, in bind outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, File [...]/site-packages/jax/core.py, line 617, in process_custom_vjp_call return fun.call_wrapped(*tracers) File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File [...]/trax/fastmath/jax.py, line 167, in _f return f(*args, **kwargs) File [...]/trax/layers/base.py, line 651, in _f res = self.forward(y) File [...]/layers/research/efficient_attention.py, line 2117, in forward output, new_state, _, _ = self.forward_and_or_backward( File [...]/layers/research/efficient_attention.py, line 2538, in forward_and_or_backward loop_val = 
fastmath.fori_loop( File [...]/trax/fastmath/ops.py, line 173, in fori_loop return backend()['fori_loop'](lower, upper, body_fn, init_val) File [...]/jax/_src/traceback_util.py, line 183, in reraise_with_filtered_traceback return fun(*args, **kwargs) File [...]/_src/lax/control_flow.py, line 212, in fori_loop (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val), File [...]/jax/_src/traceback_util.py, line 183, in reraise_with_filtered_traceback return fun(*args, **kwargs) File [...]/_src/lax/control_flow.py, line 1288, in scan init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) File [...]/_src/lax/control_flow.py, line 1274, in _create_jaxpr jaxpr, consts, out_tree = _initial_style_jaxpr( File [...]/jax/_src/util.py, line 186, in wrapper return cached(config._trace_context(), *args, **kwargs) File [...]/jax/_src/util.py, line 179, in cached return f(*args, **kwargs) File [...]/_src/lax/control_flow.py, line 76, in _initial_style_jaxpr jaxpr, consts, out_tree = _initial_style_open_jaxpr( File [...]/jax/_src/util.py, line 186, in wrapper return cached(config._trace_context(), *args, **kwargs) File [...]/jax/_src/util.py, line 179, in cached return f(*args, **kwargs) File [...]/_src/lax/control_flow.py, line 70, in _initial_style_open_jaxpr jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug) File [...]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) File [...]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(*in_tracers) File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File [...]/_src/lax/control_flow.py, line 143, in scanned_fun return (i + 1, body_fun(i, x)), None File [...]/layers/research/efficient_attention.py, line 2421, in run_inner s_all = tree_update(s_all, idx, s_h) File 
[...]/layers/research/efficient_attention.py, line 2352, in tree_update return fastmath.nested_map_multiarg( File [...]/trax/fastmath/numpy.py, line 136, in nested_map_multiarg return tuple([nested_map_multiarg(f, *[o[i] for o in objs]) File [...]/trax/fastmath/numpy.py, line 136, in return tuple([nested_map_multiarg(f, *[o[i] for o in objs]) File [...]/trax/fastmath/numpy.py, line 143, in nested_map_multiarg return f(*objs) File [...]/layers/research/efficient_attention.py, line 2353, in lambda x, y: fastmath.index_update(x, jax.ops.index[indices], y), File [...]/trax/fastmath/ops.py, line 199, in index_update return backend()['index_update'](*args, **kwargs) File [...]/_src/ops/scatter.py, line 351, in index_update return _scatter_update( File [...]/_src/ops/scatter.py, line 68, in _scatter_update return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, File [...]/_src/ops/scatter.py, line 90, in _scatter_impl y = jnp.broadcast_to(y, tuple(indexer.slice_shape)) File [...]/_src/numpy/lax_numpy.py, line 1816, in broadcast_to raise ValueError(msg.format(arr_shape, shape)) ValueError: Incompatible shapes for broadcasting: (16,) and requested shape (32,) ```