Open Totom3 opened 3 years ago
It appears that any model which contains a ReversibleLayer (or a subclass of it) will produce such an error. The following code produces the same errors. The precise error depends on whether:
- DoSomething is declared a subclass of tl.Layer or of tl.ReversibleLayer;
- PureReversible is declared a subclass of tl.Layer or of tl.ReversibleLayer;
- the model is built with tl.Serial or with tl.ReversibleSerial.
If all three use the non-reversible option, everything works. If any of the three uses the reversible option, it produces an error similar to the one above.
import trax
import trax.data as td
import trax.layers as tl
import trax.supervised.training as tt
from trax.fastmath import numpy as jnp


class DoSomething(tl.ReversibleLayer):
    def __init__(self):
        super().__init__(n_in=2, n_out=2)
        self.l = tl.Dense(10)
        self._sublayers = (self.l,)

    def forward(self, x):
        x1, x2 = x
        return x1 + self.l(x2), x2

    def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
        y1, y2 = output
        return y1 - self.l(y2), y2

    def init_weights_and_state(self, sig):
        w, s = self.l.init(sig[0])
        self.weights = (w,)
        self.state = (s,)


class PureReversible(tl.ReversibleLayer):
    def __init__(self, forw, backw, n_in, n_out):
        self.forw = forw
        self.backw = backw
        super().__init__(n_in=n_in, n_out=n_out)

    def forward(self, x):
        return self.forw(x)

    def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
        return self.backw(output)


split = tl.Fn("Split Stack", lambda x: jnp.split(x, 2), n_out=2)
merge = tl.Fn("Merge Stack", lambda x1, x2: jnp.concatenate([x1, x2]), n_out=1)
rev_split = PureReversible(split, merge, 1, 2)
rev_merge = PureReversible(merge, split, 2, 1)


def input_stream0(_=None):
    while True:
        yield (jnp.zeros((20,)), jnp.zeros((20,)))


model = tl.ReversibleSerial(rev_split, DoSomething(), rev_merge)

train_task = tt.TrainTask(
    labeled_data=td.Serial(input_stream0, td.AddLossWeights())(),
    loss_layer=tl.L2Loss(),
    optimizer=trax.optimizers.Adam(0.01))

training_loop = tt.Loop(model, train_task)
training_loop.run(1)
If there is any known workaround that does not involve completely dropping reversible layers, we would really appreciate a temporary solution, as we depend on the memory savings that reversible networks provide. Thank you!
Did you try to install Trax from the master branch?
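(Installing from master would be something like pip install git+https://github.com/google/trax.git; that command is given here only for context.)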
No, I got it from pip. However, in the end I managed to make it work by passing use_memory_efficient_trainer=True as an argument to Loop. (I'm not sure if this should be closed, so I'm leaving it as is, but I'm satisfied with the solution I found.)
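For reference, a minimal sketch of that workaround against the reproduction code above (only the Loop construction changes; everything else stays the same):

# Same model and train_task as in the reproduction code above; the only change
# is enabling the memory-efficient trainer when constructing the Loop.
training_loop = tt.Loop(model, train_task, use_memory_efficient_trainer=True)
training_loop.run(1)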
Description
ReversibleSelect seems to mess up JAX's backtracing/JIT compilation. In the code provided below, we define a simple (non-reversible) model which (1) splits the inputs; (2) does something to one input; (3) swaps the two inputs; (4) merges the inputs. Step 3 uses ReversibleSelect. The resulting model can be initialized and called, but will cause errors when attempting to train it with TrainTask. Curiously enough:
- If we replace ReversibleSelect by Select, there is no error.
- If we replace ReversibleSelect by a pure function which manually swaps the inputs, there is no error. By "pure function" we mean something like tl.Fn("Swap", lambda a, b: (b, a), n_out=2); see the sketch below.
Finally, in case this is important, the code was run on a machine without a GPU or TPU.
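For concreteness, a minimal sketch of that pure-function variant, using hypothetical stand-in Fn layers for the split, "do something", and merge steps (the actual model from the report is not reproduced here):

import trax.layers as tl
from trax.fastmath import numpy as jnp

# Hypothetical stand-ins for steps (1) and (4).
split = tl.Fn("Split", lambda x: jnp.split(x, 2), n_out=2)
merge = tl.Fn("Merge", lambda x1, x2: jnp.concatenate([x1, x2]), n_out=1)

# Step (2): apply a Dense layer to the first input, pass the second through unchanged.
do_something = tl.Parallel(tl.Dense(10), None)

# Step (3) as a pure function instead of ReversibleSelect.
swap = tl.Fn("Swap", lambda a, b: (b, a), n_out=2)

model = tl.Serial(split, do_something, swap, merge)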
Environment information
For bugs: reproduction and error logs
To reproduce, run the following code
Error logs:
Below is the full error log. I don't think it's important, but in case it is, there is also a warning about no GPU/TPU, some information about TensorFlow when Trax is first loaded, and a warning about the missing output_dir parameter in Loop.