google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0
8.1k stars 816 forks source link

[Bug] ReversibleSelect causes error when training model #1295

Open Totom3 opened 3 years ago

Totom3 commented 3 years ago

Description

ReversibleSelect seems to break JAX's gradient tracing/JIT compilation. The code below defines a simple (non-reversible) model that (1) splits the input into two halves; (2) applies a Dense layer to one of them; (3) swaps the two halves; (4) merges them back together. Step 3 uses ReversibleSelect. The resulting model can be initialized and called, but training it with TrainTask raises an error. Curiously enough, replacing ReversibleSelect with Select, or with a manual swap of the inputs, makes everything work (as noted in a comment in the reproduction code below).

Finally, in case this is important, the code was run on a machine without a GPU or TPU.

Environment information

OS: Linux 4.9.0-13-amd64 #1 SMP Debian 4.9.228-1 (2020-07-05) x86_64 GNU/Linux

$ pip freeze | grep trax
trax==1.3.6

$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.2
tensorflow-data-validation==0.22.2
tensorflow-datasets==2.0.0
tensorflow-enterprise-addons @ file:///opt/conda/conda-bld/dlenv-tf-2-1-cpu_1598328292311/work/tensorflow_enterprise_addons-0.0.0-py3-none-any.whl
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-io==0.11.0
tensorflow-metadata==0.22.2
tensorflow-model-analysis==0.22.2
tensorflow-probability==0.7.0
tensorflow-serving-api==2.1.0
tensorflow-text==2.3.0
tensorflow-transform==0.22.0

$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52

$ python -V
Python 3.7.8

For bugs: reproduction and error logs

To reproduce, run the following code

# Minimal reproduction for trax issue #1295: a model containing
# tl.ReversibleSelect can be initialized and run forward, but training it
# with tt.Loop raises a TypeError from JAX's custom-VJP machinery.
# Workaround reported in the issue thread: build the loop as
#   tt.Loop(model, train_task, use_memory_efficient_trainer=True)
import trax
import trax.data as td
import trax.layers as tl
import trax.supervised.training as tt
from trax.fastmath import numpy as jnp
trax.fastmath.use_backend('jax')

# split_stack turns one array into two equal parts (jnp.split with an integer
# count splits along axis 0 and requires an even length there); merge_stack
# concatenates two arrays back into one along axis 0.
split_stack = tl.Fn("Split Stack", lambda x: jnp.split(x, 2), n_out=2)
merge_stack = tl.Fn("Merge Stack", lambda x1, x2: jnp.concatenate([x1, x2]), n_out=1)

inputs_size = 20
def input_stream0(_=None):
    # Endless stream of (input, target) pairs, both all-zero vectors of length
    # inputs_size. The ignored optional argument presumably lets td.Serial use
    # this function as a data source -- TODO confirm.
    # NOTE(review): there is no td.Batch stage, so examples reach the model
    # unbatched; if a batch axis were added, split_stack would split along it.
    while True:
        yield (jnp.zeros((inputs_size,)), jnp.zeros((inputs_size,)))

# Model: split into two halves -> Dense on one half -> swap the halves
# (ReversibleSelect([1, 0])) -> concatenate back to a single vector.
# If one replaces ReversibleSelect by Select, or by a manual swap of inputs, everything works!
model = tl.Serial(split_stack, tl.Dense(inputs_size//2), tl.ReversibleSelect([1, 0]), merge_stack)
# Factory for the training stream: the generator piped through
# td.AddLossWeights (presumably appends per-example loss weights -- confirm).
in_stream = lambda: td.Serial(input_stream0, td.AddLossWeights())()

train_task = tt.TrainTask(
    labeled_data = in_stream(),
    loss_layer = tl.L2Loss(),
    optimizer = trax.optimizers.Adam(0.01))

# No output_dir is passed, so Loop emits a warning about it; that warning is
# unrelated to the failure. The TypeError is raised during this single step.
training_loop = tt.Loop(model, train_task)
training_loop.run(1)

Error logs:

Below is the full error traceback. I don't think it matters, but the output also contained a warning that no GPU/TPU was found, some TensorFlow information printed when trax was first imported, and a warning about the missing output_dir parameter of Loop; those lines are not included here.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-2a3851881ca5> in <module>
     24 
     25 training_loop = tt.Loop(model, train_task)
---> 26 training_loop.run(1)

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    337         task_index = self._which_task(self._step)
    338         task_changed = task_index != prev_task_index
--> 339         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    340 
    341         # optimizer_metrics and loss are replicated on self.n_devices, a few

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    450       trainer.accelerated_loss_layer.replicate_weights(model.weights)
    451       trainer.accelerated_loss_layer.replicate_state(model.state)
--> 452     return trainer.one_step(batch, rng, step=step, learning_rate=learning_rate)
    453 
    454   def _log_training_progress(self, task, total_loss, n_steps, elapsed_time,

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    129     # NOTE: stats is a replicated dictionary of key to jnp arrays.
    130     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 131         (weights, self._slots), step, self._opt_params, batch, state, rng)
    132 
    133     if logging.vlog_is_on(1) and ((step & step - 1) == 0):

/opt/conda/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    169     flat_fun, out_tree = flatten_fun(f, in_tree)
    170     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 171                        name=flat_fun.__name__, donated_invars=donated_invars)
    172     return tree_unflatten(out_tree(), out)
    173 

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1132 
   1133   def bind(self, fun, *args, **params):
-> 1134     return call_bind(self, fun, *args, **params)
   1135 
   1136   def process(self, trace, fun, tracers, params):

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1121   if top_trace is None:
   1122     with new_sublevel():
-> 1123       outs = primitive.impl(fun, *args, **params)
   1124   else:
   1125     tracers = map(top_trace.full_raise, args)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    525 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
    526   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 527                                *unsafe_map(arg_spec, args))
    528   try:
    529     return compiled_fun(*args)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    222       fun.populate_stores(stores)
    223     else:
--> 224       ans = call(fun, *args)
    225       cache[key] = (ans, fun.stores)
    226     return ans

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    596     pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    597     jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 598         fun, pvals, instantiate=False, stage_out=True, bottom=True)
    599   map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
    600   jaxpr = apply_outfeed_rewriter(jaxpr)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    421   with core.new_master(trace_type, bottom=bottom) as master:
    422     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 423     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    424     assert not env
    425     del master

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    173       weights, slots = weights_and_slots
    174       (loss, state), gradients = forward_and_backward_fn(
--> 175           batch, weights, state, rng)
    176       weights, slots, stats = optimizer.tree_update(
    177           step, gradients, weights, slots, opt_params)

/opt/conda/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    491     dtype = dtypes.result_type(ans)
    492     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
--> 493     g = vjp_py(np.ones((), dtype=dtype))
    494     g = g[0] if isinstance(argnums, int) else g
    495     if not has_aux:

/opt/conda/lib/python3.7/site-packages/jax/api.py in _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args)
   1458              "match type of corresponding primal output ({})")
   1459       raise TypeError(msg.format(_dtype(a), dtype))
-> 1460   ans = fun(*args)
   1461   return tree_unflatten(out_tree, ans)
   1462 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in unbound_vjp(pvals, jaxpr, consts, *cts)
    115     cts = tuple(map(ignore_consts, cts, pvals))
    116     dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
--> 117     arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
    118     return map(instantiate_zeros, arg_cts)
    119 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
    204       else:
    205         cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
--> 206                                                          **eqn.params)
    207     cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    208     # FIXME: Some invars correspond to primals!

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in _custom_lin_transpose(cts_out, num_res, bwd, avals_out, *invals)
    609   res, _ = split_list(invals, [num_res])
    610   cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
--> 611   cts_in = bwd.call_wrapped(*res, *cts_out)
    612   cts_in_flat, _ = tree_flatten(cts_in)  # already checked tree structure
    613   return [None] * num_res + cts_in_flat

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
    152     while stack:
    153       gen, out_store = stack.pop()
--> 154       ans = gen.send(ans)
    155       if out_store is not None:
    156         ans, side = ans

/opt/conda/lib/python3.7/site-packages/jax/custom_derivatives.py in _flatten_bwd(in_tree, out_trees, *args)
    510            "number of arguments to the primal function, but got VJP output "
    511            "structure {} for primal input structure {}.")
--> 512     raise TypeError(msg.format(in_tree2, in_tree)) from None
    513   yield cts_in
    514 

TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef(tuple, [PyTreeDef(None, []),PyTreeDef(None, []),PyTreeDef(tuple, [*,*]),PyTreeDef(tuple, [])]) for primal input structure PyTreeDef(tuple, [PyTreeDef(tuple, []),*,PyTreeDef(tuple, [*,*]),PyTreeDef(tuple, [])]).
Totom3 commented 3 years ago

Update

It appears that any model containing a ReversibleLayer (or a subclass of it) produces such an error. The following code produces the same error. The precise error message depends on whether:

  1. ...my class DoSomething is declared a subclass of tl.Layer or tl.ReversibleLayer
  2. ...my class PureReversible is declared a subclass of tl.Layer or tl.ReversibleLayer
  3. ...the whole model is assembled with tl.Serial or tl.ReversibleSerial.

If all 3 use the non-reversible options, everything works. If any of the 3 use the reversible option, it produces an error similar to the above.

import trax
import trax.data as td
import trax.layers as tl
import trax.supervised.training as tt
from trax.fastmath import numpy as jnp

class DoSomething(tl.ReversibleLayer):
    """Reversible two-input layer: (x1, x2) -> (x1 + Dense(x2), x2).

    The inverse subtracts the same Dense output from the first element, so
    inputs can be recomputed from outputs.
    """

    def __init__(self):
        super().__init__(n_in=2, n_out=2)
        # Dense with 10 output units; x1 must broadcast with a width-10 result
        # for the addition in forward to work.
        self.l = tl.Dense(10)
        # Registered after super().__init__ so the base constructor does not
        # overwrite it.
        self._sublayers = (self.l,)

    def forward(self, x):
        """Return (x1 + l(x2), x2) for the input pair x = (x1, x2)."""
        x1, x2 = x
        # NOTE(review): self.l is called with whatever weights are stored on
        # it, not with weights threaded through this layer's forward pass.
        # Under the trainer's traced gradient computation the Dense weights
        # may therefore receive no gradient -- confirm against trax's
        # Layer.pure_fn / weight-passing API.
        return x1 + self.l(x2), x2

    def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
        """Recover (x1, x2) from output = (y1, y2) by undoing the addition.

        weights, state, new_state and rng are accepted but ignored; self.l
        again uses its own stored weights.
        """
        y1, y2 = output
        return y1 - self.l(y2), y2

    def init_weights_and_state(self, sig):
        """Initialize the Dense sublayer and store its weights/state as 1-tuples."""
        # NOTE(review): the Dense layer is applied to x2 (sig[1]) in forward,
        # but is initialized from sig[0]; this only works when both inputs
        # have the same shape -- confirm this is intended.
        w, s = self.l.init(sig[0])

        self.weights = (w,)
        self.state = (s,)

class PureReversible(tl.ReversibleLayer):
    """Reversible layer assembled from a forward callable and its inverse.

    `forw` maps inputs to outputs and `backw` maps outputs back to inputs.
    This class does not add either callable to `_sublayers`, so any weights
    they might hold are not managed by this layer.
    """

    def __init__(self, forw, backw, n_in, n_out):
        # Keep both directions (forward first, then inverse) before the base
        # Layer constructor runs.
        self.forw, self.backw = forw, backw
        super().__init__(n_in=n_in, n_out=n_out)

    def forward(self, x):
        """Apply the forward callable to the layer inputs."""
        apply_forward = self.forw
        return apply_forward(x)

    def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
        """Reconstruct the inputs by applying the inverse callable to `output`.

        `weights`, `state`, `new_state` and `rng` are part of the reversible
        interface but are not used here.
        """
        apply_backward = self.backw
        return apply_backward(output)

# Split a vector into two equal halves along axis 0, and concatenate two
# arrays back into one; each is the inverse of the other.
split = tl.Fn("Split Stack", lambda x: jnp.split(x, 2), n_out=2)
merge = tl.Fn("Merge Stack", lambda x1, x2: jnp.concatenate([x1, x2]), n_out=1)

# Reversible wrappers pairing each Fn with its inverse:
# rev_split takes 1 input -> 2 outputs, rev_merge takes 2 inputs -> 1 output.
rev_split = PureReversible(split, merge, 1, 2)
rev_merge = PureReversible(merge, split, 2, 1)

def input_stream0(_=None):
    # Endless stream of (input, target) pairs of all-zero length-20 vectors,
    # so each half after `split` has length 10 (matching Dense(10) in
    # DoSomething). No td.Batch stage: examples are fed unbatched.
    while True:
        yield (jnp.zeros((20,)), jnp.zeros((20,)))

# Per the issue, using any reversible component here (a ReversibleLayer
# subclass or ReversibleSerial) triggers the same TypeError during training;
# passing use_memory_efficient_trainer=True to tt.Loop was reported to avoid it.
model = tl.ReversibleSerial(rev_split, DoSomething(), rev_merge)

train_task = tt.TrainTask(
    labeled_data = td.Serial(input_stream0, td.AddLossWeights())(),
    loss_layer = tl.L2Loss(),
    optimizer = trax.optimizers.Adam(0.01))

# A single training step is enough to hit the error.
training_loop = tt.Loop(model, train_task)
training_loop.run(1)

If there is any known workaround that does not involve dropping reversible layers entirely, we would really appreciate a temporary solution, as we depend on the memory savings provided by reversible networks. Thank you!

thoo commented 3 years ago

Have you tried installing Trax from the master branch?

Totom3 commented 3 years ago

No, I installed it from pip. However, in the end I managed to make it work by passing use_memory_efficient_trainer=True as an argument to Loop. (I'm not sure whether this issue should be closed, so I'm leaving it open, but I'm satisfied with this workaround.)