Closed albertz closed 3 years ago
The example should be extended:
How are nested loops handled? This should be straight-forward as well. And nothing must be ambiguous about it, e.g. the (implicit or explicit) definition and handling of hidden state.
What are reasonable sub modules with hidden state? How would you define them? Just normal `Module`s? Or other `Rec`s where you call `step`?
Custom ending condition, how to define?
How to get the last hidden state?
Also related: https://github.com/rwth-i6/returnn/issues/391 https://github.com/rwth-i6/returnn/pull/545
I.e. the question about how to define masked self-attention in a generic/flexible way in this `Rec` concept. E.g. `HistoryMaskLayer` or so. This could all be wrapped directly, but is this intuitive then? What does it look like? Can we make it more intuitive?
Also: Is there anything special needed for `Choice` (`ChoiceLayer`)? And beam search? Maybe not and it is all straight-forward already.
How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?
> How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?

From what I can see, the RETURNN handling should be enough, assuming that automated changes to the config can be done elsewhere.
> Custom ending condition, how to define?

Maybe give an optional parameter for `Rec` like `Rec.custom_break` which will be checked internally at the beginning of the loop?
> `self.h = self.state({batch, dim})`

How do you concretely think of this init? This requires an explicit value for the batch dimension, right? I am not sure if this is what we should aim for; RETURNN should be able to handle this already? Or is this right now just for example's sake?
> ```python
> with rec.loop() as loop:  # this introduces a new loop
>     h_ = rec(x)  # shape {batch,dim}. this represents the inner value
>     h = loop.last(h_)  # shape {batch,dim}
>     out = loop.stack(h_)  # shape {time,batch,dim}
> ```

Maybe this example (esp. for the doc) could be extended with another (non-rec) module, showing how to use it in a `Module`, so that someone who has not worked with `Rec` in this way yet directly has an example which can be modified. Maybe something like (assuming this is how it's intended):
```python
class EncDec(Module):
    def __init__(self):
        self.enc = Linear(dim)
        self.dec = MyRec()

    def forward(self, x):
        x = self.enc(x)
        with self.dec.loop() as loop:
            h_ = self.dec(x)  # shape {batch,dim}. this represents the inner value
            h = loop.last(h_)  # shape {batch,dim}
            out = loop.stack(h_)  # shape {time,batch,dim}
        return out
```
> Also: Is there anything special needed for `Choice` (`ChoiceLayer`)? And beam search? Maybe not and it is all straight-forward already.
We are mainly aiming for conversion to RETURNN here, so maybe if we choose a smart way to include them in the loop, plus some internal handling, this can be converted (since in RETURNN they are part of the rec unit anyway).
> How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?
>
> From what I can see RETURNN handling should be enough. Assuming using this with automation changes in config can be done elsewhere I think.
Note that this is not just about what is enough to be able to define all networks. Or maybe I'm not sure how you mean it.
It's about being straight-forward and clear. I.e. at no time should it be unclear to the user when search is used.
We do not have to follow exactly the behavior of RETURNN. There are also multiple ways in RETURNN. We can restrict it to one clean way. We can also change it. Or introduce a simpler variant here.
I'm tending to make it explicit. But not sure.
PyTorch also has a similar concept for the train flag (as we have as well in RETURNN). I.e. some PyTorch modules behave differently depending on whether they are in train or eval mode (e.g. `Dropout`). We have exactly the same in RETURNN. And search is a flag like train.
The difference is how these flags are set:

- In RETURNN, this is all global, and for the search flag, there are some additional (maybe unintuitive) ways to overwrite it. And the flags are implied automatically in RETURNN, depending e.g. on the task, and the user does not have much control over it. It is quite hidden.
- In PyTorch, there are no implicitly implied global flags. Every module has its own flag, and it is set explicitly (and easily recursively for all sub modules). Every module has the train flag set initially, and you can disable it explicitly. So to the user, it's always clear how the flags are set, because the user sets them, with no automatic behavior. The user explicitly writes `model.train()` or `model.eval()`.
Maybe again, here in returnn-common, we can follow the PyTorch style a bit for this, and also copy it for the search flag? Not sure...
Edit: I think this is actually a separate thing, which we can discuss/handle independently. So I opened #18 about this. And this can be marked as off-topic here.
> Custom ending condition, how to define?
> Maybe give an optional parameter for `Rec` like `Rec.custom_break` which will be checked internally at the beginning of the loop?

In RETURNN, the layer is called `end`. So maybe additionally to `step`, the user can define a method `end` (kind of similar to how `tf.while_loop` gets `body` and `cond`).
This `end` method can be optional. E.g. when `unstack` is used at some point, then the sequence length is implied from this. The same is true when cross-entropy loss is used, as `unstack` would be used on the targets.
> `self.h = self.state({batch, dim})`
>
> How do you concretely think of this init? This requires an explicit value for the batch dimension right? I am not sure if this is what we should aim for, RETURNN should be able to handle this already? Or is this right now just for example sake?

`batch` here would be the dimension tag (#17). I.e. this is kind of a placeholder. This does not require an explicit value. And yes, RETURNN has all that logic already.
> ```python
> with rec.loop() as loop:  # this introduces a new loop
>     h_ = rec(x)  # shape {batch,dim}. this represents the inner value
>     h = loop.last(h_)  # shape {batch,dim}
>     out = loop.stack(h_)  # shape {time,batch,dim}
> ```
>
> Maybe this example (esp. for the doc) could be extended with another (non-rec) module, showing how to use it in a `Module`, so that someone who has not worked with `Rec` in this way yet directly has an example which can be modified. Maybe something like (assuming this is how it's intended):
>
> ```python
> class EncDec(Module):
>     def __init__(self):
>         self.enc = Linear(dim)
>         self.dec = MyRec()
>
>     def forward(self, x):
>         x = self.enc(x)
>         with self.dec.loop() as loop:
>             h_ = self.dec(x)
>             h = loop.last(h_)
>             out = loop.stack(h_)
>         return out
> ```

Yes sure. Although I thought this part was clear.
> Also: Is there anything special needed for `Choice` (`ChoiceLayer`)? And beam search? Maybe not and it is all straight-forward already.
>
> We are mainly aiming for conversion to RETURNN here, so maybe if we choose a smart way to include them in the loop and then some internal handling this can be converted (since in returnn they are part of the rec unit anyways).
No, simplicity, clear and straight-forward behavior has priority over everything. In the end, everything can be converted to RETURNN, because RETURNN should be generic to allow for everything (and if that is not the case, we should simply extend RETURNN itself). Esp if the RETURNN behavior is unintuitive or not straight-forward, we should not just copy that here.
I'm not sure what you mean by "smart way", and "internal handling" for what?
Right now, you would just use `Choice` as you would use other layers like `Linear`. I was wondering whether we need any special behavior for `Choice` here in returnn-common, or not. Currently I don't really see a reason for special behavior. But maybe I'm overlooking sth.
This is of course related to the search aspect. Although, even when we can set the search flag explicitly in returnn-common, the logic of how the `ChoiceLayer` behaves with or without the search flag is already all part of RETURNN, so no special handling is needed for that here.
Note: I changed the original suggestion of automatically stacking the output, and automatically implying the construction of the loop, because it was not clear for sub modules with hidden state when they should create a new nested loop or reuse the same loop. So now this is explicit, via `loop()`.
However, I'm not sure if this is now the most intuitive, clean, straight-forward way. Although other ways I can think of right now are not really better either.
When we keep using `Module` everywhere and do not use `Rec` as a base class, but use `Rec.State()` and `Rec.Loop()` explicitly, this would already work in this draft.
(We can still use `Rec` as a namespace to cover all rec related logic. Or maybe another Python module/package for this. Or just use `RecLoop()` instead of `Rec.Loop()`. But this is not relevant for now.)
It's just a technical question then about how to handle this code:
```python
# Definition for a single step:
class MyRec(Module):
    ...

# Now the loop:
rec = MyRec()
with Rec.Loop() as loop:
    x_ = loop.unstack(x)
    y_ = rec(x_)
    y = loop.stack(y_)
```
Or, alternatively to `Rec.Loop()`, maybe `Rec` stays a `Module` where you define the inner `step`, which would just exactly contain that code:
```python
# Definition for a single step:
class MyRec(Module):
    ...

# The loop:
class MyRecLoop(Rec):
    def __init__(self):
        super().__init__()
        self.rec = MyRec()

    def step(self, x):
        x_ = self.unstack(x)
        y_ = self.rec(x_)
        return self.stack(y_)

loop = MyRecLoop()
y = loop(x)
```
> Currently I don't really see a reason for special behavior. But maybe I'm overlooking sth.
I agree with that; right now it does not seem like we need it, but again I can report back once I have finished my config whether there is some weird behavior we didn't think of.
> I'm not sure what you mean by "smart way", and "internal handling" for what?
That was weird phrasing. What I meant is that I would have all of that `Choice` and beam search handling in `returnn_common` and not leave it to the user to find the "correct" way to use the layers; best case it should just be usable as you said, similar to a `Linear` layer imo.
Also, in some earlier draft, there was a separate `init` (or `initial_state`, or `step0`) method to define the initial state, next to `step`, with the same arguments as `step` (i.e. input `x` in the example). I changed that now, and merged the logic into `step`, because:
- When you have some computation based on `x`, e.g. some embedding, which you use both for the initial state and other things in the loop, this is possible now. This was not really possible with a separate `init` method.
- The logic of when `init` would have been called seemed to require a bit too much unintuitive magic for me.
- It avoids duplicating logic between `init` and `step`.

However, having both kinds of logic merged in `step` also has downsides. E.g. maybe this is a bit unintuitive.
Also, I wonder whether it makes sense to decouple the initial state logic more. Maybe for some given rec module, the user later wants to overwrite just the initial state with some custom logic. This is an important aspect. This aspect is also relevant for rec sub modules. This can be recursive.
Edit: Ok, we can simply allow that `init` of a state (e.g. `rec.h.init(...)`) can be called multiple times. Also, the `loop` object can collect all state vars and allow for easy overwriting of all initial state vars recursively. Then it is basically solved.
Do we require from the user that a rec module is written in a way that it can operate automatically both on a frame-by-frame level (this is the natural way) and also on a whole sequence-level?
Or do we actually require anything special for this, or this is anyway already all automatic?
Most (all?) builtin modules (e.g. `Linear` etc.) already handle this fine. This is also true for modules with internal state like `Lstm`.
Although we should maybe distinguish between modules where it does not matter at all (e.g. `Linear`) and those which really have different implementations and logic (`Lstm`).
Maybe this is actually bad, e.g. in the case of `Lstm`, that this has different automatically implied behavior? The current behavior is implied by whether there is a time dim (assumes outside loop, operates on the time dim) or not (assumes inside loop). This might be problematic for nested loops. But maybe not.
I think it's actually fine when a module/layer always derives the behavior from the shape, e.g. whether there is a time dim, independent of the context (whether in a loop or not). And this is maybe already the case. Or if not, it would cause errors/exceptions, and these can be fixed individually, on the RETURNN side.
What happens though when a rec module is called without an outer loop? It would simply just calculate one step, on the input? And use the defined initial state? Is there any problem with that? What happens on `unstack`?
Consider this example for an LSTM on a single step:
```python
class Lstm(Rec):
    def __init__(self):
        super().__init__()
        self.h = self.state({batch, dim})
        self.c = self.state({batch, dim})
        self.ff_linear = Linear(dim * 4)
        self.rec_linear = Linear(dim * 4)

    def step(self, x):
        # x shape is {batch,dim} (single frame)
        x_ = self.ff_linear(x)
        h_ = self.rec_linear(self.h.prev())
        x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
        c = self.c.prev() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
        self.c.assign(c)
        h = tanh(c) * sigmoid(g_out)
        self.h.assign(h)
        return h
```
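To make the gate arithmetic above concrete, here is a plain-Python scalar sketch (a hypothetical toy, not the proposed API: `lstm_step`, the scalar weight lists `w_ff`/`w_rec`, and the scalar shapes are all made up for illustration; the real `Linear` layers produce `dim * 4` pre-activations per matrix multiply):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h_prev, c_prev, w_ff, w_rec):
    # Scalar analog of step() above: ff_linear and rec_linear each produce
    # four pre-activations, which are summed and split into the four gates.
    x_ = [w * x for w in w_ff]
    h_ = [w * h_prev for w in w_rec]
    x_in, g_in, g_forget, g_out = [a + b for a, b in zip(x_, h_)]
    c = c_prev * sigmoid(g_forget) + math.tanh(x_in) * sigmoid(g_in)  # self.c.assign(c)
    h = math.tanh(c) * sigmoid(g_out)                                 # self.h.assign(h)
    return h, c

h, c = lstm_step(1.0, 0.0, 0.0, w_ff=[1.0] * 4, w_rec=[0.5] * 4)
print(h, c)
```

Iterating `lstm_step` over a sequence, feeding `h`/`c` back in, is exactly what the loop around `step` would do.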
Now when called with an input which has a time-dim, should that automatically introduce the loop around it? No, there should not be any fancy automatic implied logic. The automatic optimization would happen on RETURNN side. I think this should be up to the user, to explicitly add code for that.
To answer the original question: No, we do not require this.
But how is the behavior of `state.h.prev()`, `state.h.init()`, `state.h.assign()` defined exactly in each case? I.e. without any loop, with a loop, or with multiple nested loops. Or with multiple calls to the module?
To extend the original question: Do we allow this? Can the user write this `Lstm` module such that it works both for a single step and for a whole sequence?
The input to `step` for rec modules is currently defined to be outside the loop, or to be the same in each frame. Or is it actually? In the example, `x` is outside the loop. But is that relevant for the loop (except if the user calls `unstack` on it)?
For rec sub modules, the input could also be some value from within the current loop (as in the example for rec sub modules).
Do we need to take any special care here? Or is all just fine?
I think this is all fine. I don't see how this has any relevance.
Why do we need `self.state`, and not a global `Rec.State`? The state would always be related to the most inner loop after a rec module call (where the state `assign` happens).
Do we always need the `Rec` base class? We could use a normal `Module`, and just `forward` instead of `step`.
Also a `Module` might define states (via `Rec.State`).
Whether it has "state" cannot depend on the input, because the state is defined in `__init__`.
Although, whether and how the state is used can be dynamic in the `forward` function.
E.g. an `Lstm` step as a module for a single step:
```python
class Lstm(Module):
    def __init__(self):
        super().__init__()
        self.h = Rec.State({batch, dim})
        self.c = Rec.State({batch, dim})
        self.ff_linear = Linear(dim * 4)
        self.rec_linear = Linear(dim * 4)

    def forward(self, x):
        # x shape is {batch,dim} (single frame)
        x_ = self.ff_linear(x)
        h_ = self.rec_linear(self.h.prev())
        x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
        c = self.c.prev() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
        self.c.assign(c)
        h = tanh(c) * sigmoid(g_out)
        self.h.assign(h)
        return h

lstm = Lstm()
with Rec.Loop() as loop:
    x_ = loop.unstack(x)
    y_ = lstm(x_)
    y = loop.stack(y_)
```
How does the `Rec` base class change anything here?
We could also allow a loop around a normal `Module`. Sth like:
```python
mod = MyModule()
with Rec.Loop() as loop:
    x_ = loop.unstack(x)
    y_ = mod(x_)
    y = loop.stack(y_)
```
Although, due to RETURNN automatic optimization, this would be equivalent to:
```python
mod = MyModule()
y = mod(x)
```
So, if this is possible, this touches on the question of whether we need the `Rec` base class.
The question is if this is always possible or would result in problems.
> Note: I changed the original suggestion of automatically stacking the output, and automatically implying the construction of the loop, because it was not clear for sub modules with hidden state when they should create a new nested loop or reuse the same loop. So now this is explicit, via `loop()`.
>
> When we keep using `Module` everywhere and not use `Rec` as a base class, and `Rec.State()` and `Rec.Loop()` explicitly, this would work already in this draft. (We can still use `Rec` as a namespace to cover all rec related logic. Or maybe another Python module/package for this. Or just use `RecLoop()` instead of `Rec.Loop()`. But this is not relevant for now.) It's just a technical question then about how to handle this code:
>
> ```python
> # Definition for a single step:
> class MyRec(Module):
>     ...
>
> # Now the loop:
> rec = MyRec()
> with Rec.Loop() as loop:
>     x_ = loop.unstack(x)
>     y_ = rec(x_)
>     y = loop.stack(y_)
> ```
>
> Or, alternatively to `Rec.Loop()`, maybe `Rec` stays a `Module` where you define the inner `step`, ...
Thinking further about this, I now prefer the solution to just have it all as a normal `Module`, with no special `Rec` base module class with `step`. And then there is the special `Rec.State`. Or instead of `Rec.State`, maybe just `State`. Which every normal `Module` can use, in its `__init__`.
And further, I think it is most straight-forward to not have any init logic inside `step`/`forward` (there is no special `step` anymore when we just use `Module` always).
The reasoning is that I think the easiest way for the user to think about the logic is when `with RecLoop() as loop:` can be seen as analog/equivalent to a `for i in range(...)` or `while ...` loop. There should be no special cases such as "this init calculation would only happen before the first step". The mental model for the user should really exactly allow for this direct translation. From this simple principle, we would infer all the other behavior.
This implies that there should not be any init logic in `forward`, and it would really just execute a single step when inside the loop.
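This mental model can be sketched as a plain-Python analog (a hypothetical toy, not the actual API; the accumulating step function is made up for illustration): the loop unstacks the input, runs one step per frame on the previous state, and stacks the per-step outputs.

```python
def run_loop(xs):
    h = 0                  # state; initial value assigned outside the loop
    stacked = []
    for x_ in xs:          # corresponds to x_ = loop.unstack(x)
        h = h + x_         # one step: read previous state, assign new state
        stacked.append(h)  # corresponds to loop.stack(...)
    return stacked, h      # stacked outputs, and the loop.last(...) value

ys, last = run_loop([1, 2, 3])
print(ys, last)  # [1, 3, 6] 6
```

Nothing runs "before the first step"; the first iteration simply reads whatever was assigned outside the loop.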
Further, this `State` object would have just an `assign` method. There is no need for a special `init`. The `assign` can infer from the scope where it belongs (e.g. some init outside the loop). And also, it would just have a `get` or `read` method. There is no need for `prev` because we can simply infer internally whether we need `"prev:"`. Thus, if you would want some special init, you would just write it like this:
```python
# Definition for a single step:
class MyRec(Module):
    ...

# Now the loop:
rec = MyRec()
rec.h.assign(...)  # init
with Rec.Loop() as loop:
    x_ = loop.unstack(x)
    y_ = rec(x_)
    y = loop.stack(y_)
```
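In the while-loop reading, the `assign` before the loop is just overwriting a variable before entering the loop. A plain-Python sketch (hypothetical; the step function and the `h_init` parameter are made up for illustration):

```python
def run(xs, h_init=0):
    h = h_init          # rec.h.assign(...) before the loop overwrites the default init
    ys = []
    for x_ in xs:
        h = h * 2 + x_  # the step reads the previous h and assigns the new one
        ys.append(h)
    return ys

print(run([1, 1]))            # default init: [1, 3]
print(run([1, 1], h_init=3))  # custom init changes all subsequent steps: [7, 15]
```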
Technically, some extra logic is maybe needed to define the RETURNN layer names based on this code. The `Rec.Loop()` here would create the `RecLayer` somehow, but this does not really have a name. The `rec` here might have a name (because it usually would be part of the main module as a sub module). So we could use that name (and scope) if there is just a single module call inside the loop. If there are multiple module calls inside the loop, it is not so straight-forward. Maybe we take the first. Or we somehow introduce anonymous layer names in RETURNN (so the `RecLayer` would not really have a name and not introduce a new level in the name scope hierarchy). But this aspect is only about the wrapping and layer names; the user would maybe not care so much about this, but more about the logic of the model / calculation.
On the question whether RETURNN automatic optimization might cause any problems: RETURNN already should guarantee that it is equivalent. From the user view point, it never ever should matter whether it is optimized. Otherwise this is https://github.com/rwth-i6/returnn/issues/573. On this returnn-common level, it should not matter. (Maybe we want to introduce potential optimization also on this higher level. But this would be another separate topic, and we can deal with that later.)
> If there are multiple module calls inside the loop, it is not so straight-forward.
You mean something like:
```python
rec = MyRec()
rec.h.assign(...)  # init
with Rec.Loop() as loop:
    x_ = loop.unstack(x)
    y_ = rec(x_)
    y = loop.stack(y_)
with Rec.Loop() as loop:
    x_ = loop.unstack(y)
    y_ = rec(x_)
    y = loop.stack(y_)
```
? I am not sure if we should allow this behavior, because I feel like this can cause quite some confusion. I think it would be okay to require that the user uses a `rec` only once. Or maybe I am missing the point here.
> Maybe we want to introduce potential optimization also on this higher level. But this would be another separate topic, and we can deal with that later.
If we do, we should be careful not to have these two optimizations interfere with each other. The only case I can think of is an optimization RETURNN is not able to do because of the config structure, but from what I know there should not be such cases? Otherwise we might optimize something RETURNN would currently optimize, and after changing RETURNN behavior (for whatever reason) this might then not work anymore (since RETURNN would like to optimize differently).
> If there are multiple module calls inside the loop, it is not so straight-forward.
>
> You mean something like:
>
> ```python
> rec = MyRec()
> rec.h.assign(...)  # init
> with Rec.Loop() as loop:
>     x_ = loop.unstack(x)
>     y_ = rec(x_)
>     y = loop.stack(y_)
> with Rec.Loop() as loop:
>     x_ = loop.unstack(y)
>     y_ = rec(x_)
>     y = loop.stack(y_)
> ```
No. I was referring to the case of multiple module calls inside a loop, e.g. like:
```python
rec1 = MyRec1()
rec2 = MyRec2()
with Rec.Loop() as loop:
    y1_ = rec1(x)
    y2_ = rec2(y1_)
    y = loop.stack(y2_)
```
But I was anyway just referring to the question of how this is mapped to RETURNN layers. Which is not so important, just a technical detail.
The (model/computation) behavior in all the examples is very clear. This always follows from the principle I stated above (`Rec.Loop()` simply corresponds to a Python `while` loop).
> ? I am not sure if we should allow this behavior, because I feel like this can cause quite some confusion.
I don't think there is any confusion? Just think of `Rec.Loop()` as a while-loop. It's very clear what happens then.
> Maybe we want to introduce potential optimization also on this higher level. But this would be another separate topic, and we can deal with that later.
>
> If we do we should be careful not to have these two optimizations interfere with eachother. The only case I can think of is optimization RETURNN is not able to do because of the config structure, but from what I know there should not be these cases? Otherwise we might optimize something RETURNN would currently optimize and after changing RETURNN behavior (for whatever reason) this might then not be working anymore (since RETURNN would like to optimize differently).
I have not really thought this fully through. I was more thinking about the case when the user might want to provide a more efficient implementation for a whole sequence (e.g. for the LSTM example). Or more relevant maybe for self-attention (see also the referenced issues above). But I think we don't have to consider this now.
> Custom ending condition, how to define?
> Maybe give an optional parameter for `Rec` like `Rec.custom_break` which will be checked internally at the beginning of the loop?
>
> In RETURNN, the layer is called `end`. So maybe additionally to `step`, the user can define a method `end` (kind of similar as `tf.while_loop` gets `body` and `cond`).
> This `end` method can be optional. E.g. when `unstack` is used at some point, then the sequence length is implied from this. The same is true when cross-entropy loss is used, as `unstack` would be used on the targets.
To update on this: When we now stick to `Module` as the base class (no `Rec`), I would not use a special additional `end` method (the `Module` just has the `forward` method).
I think I would instead make it a method of the `Loop` object. Like this:
```python
rec = MyRec()
with Loop() as loop:
    y_ = rec(x)
    y = loop.stack(y_)
    loop.end(y_ == EOS)  # example end condition
```
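In the while-loop mental model, `loop.end(...)` is just the loop's termination check. A plain-Python sketch (hypothetical; the decrementing step and the `EOS` value are made up stand-ins for a real decoder step):

```python
EOS = 0

def decode(first, max_steps=100):
    ys = []
    y = first
    for _ in range(max_steps):
        ys.append(y)   # loop.stack(y_)
        if y == EOS:   # loop.end(y_ == EOS): stop once the condition is true
            break
        y = y - 1      # next step output (stand-in for y_ = rec(x))
    return ys

print(decode(2))  # [2, 1, 0]
```

Note that the EOS frame itself is included in the stacked output here; whether that frame should be kept is exactly the `include_eos` question below.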
There is also the question of the `include_eos` option of the `RecLayer`. This might be an option for `Loop` (like `Loop(include_eos=True)`), or a method, like `loop.set_include_eos(True)`.
> Also related: rwth-i6/returnn#391 rwth-i6/returnn#545
> I.e. the question about how to define masked self-attention in a generic/flexible way in this `Rec` concept. E.g. `HistoryMaskLayer` or so. This could all be wrapped directly but is this intuitive then? How does it look like? Can we make it more intuitive?
Extending on this, example draft:
```python
with Loop() as loop:
    # x is [B,D] inside, [T,B,D] outside
    qkv = Linear(2 * K + V)(x)  # [B,2*K+V] inside, [T,B,2*K+V] outside
    q, k, v = split(qkv, size_splits=[K, K, V])  # [B,K|V] inside, [T,B,K|V] outside
    k_accum = cum_concat(k)  # [B,T',K] inside, [B,T'',K] outside
    v_accum = cum_concat(v)  # [B,T',V] inside, [B,T'',V] outside
    energy = dot(q, k_accum, red1="static:-1", red2="static:-1", var1="T?", var2="T")  # [B,T] inside, [T,B,T'] outside
    att_weights = softmax_over_spatial(energy)  # [B,T] inside, [T,B,T'] outside
    att = dot(v_accum, att_weights, red1="T", red2="stag:history", var1="static:-1", var2=[])  # [B,V] inside, [T,B,V] outside
```
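To illustrate the accumulate-history pattern, here is a toy scalar version in plain Python (a hypothetical sketch: scalar queries/keys/values instead of `[B,K]`/`[B,V]` tensors, and `masked_self_attention` is a made-up name). The lists `k_accum`/`v_accum` play the role of `cum_concat`: at step t, attention only sees the history up to t.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(v - m) for v in xs]
    s = sum(es)
    return [e / s for e in es]

def masked_self_attention(q_seq, k_seq, v_seq):
    out = []
    k_accum, v_accum = [], []
    for q, k, v in zip(q_seq, k_seq, v_seq):
        k_accum.append(k)  # like cum_concat(k): grow the key history
        v_accum.append(v)  # like cum_concat(v): grow the value history
        energy = [q * kh for kh in k_accum]  # dot over the history axis
        weights = softmax(energy)            # softmax_over_spatial
        out.append(sum(w * vh for w, vh in zip(weights, v_accum)))
    return out

print(masked_self_attention([1.0, 1.0], [0.0, 0.0], [2.0, 4.0]))  # [2.0, 3.0]
```

Step 1 can only attend to itself; step 2 attends uniformly over both history entries, hence the average 3.0.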
Note that the `cum_concat` behavior is conceptually similar to our `loop.stack`. Maybe this should be unified?
Note that we have the principle that the user should not need to think about the automatic optimization (https://github.com/rwth-i6/returnn/issues/573).
The axes descriptions and axes themselves still need to be worked out (see referenced RETURNN issue above).
`softmax_over_spatial` (`SoftmaxOverSpatialLayer`) needs to be clever about the history time axis.
We should also work out how masked computation looks on the example of a transducer with SlowRNN and FastRNN. Again, in this example, what we want: a single `Loop()`, and `Choice()` (or `choice()`) on the probability distribution. Example code draft:
```python
x  # shape {batch,enc_time,dim}
slow_rnn = SlowRNN()
fast_rnn = FastRNN()
blank_pred = Linear(1)
non_blank_pred = Linear(...)
t = State({Batch}, dtype=int32, initial=0)
align_label = State({Batch}, dtype=int32, initial=0)
with Loop() as loop:  # over alignment labels
    x_t = x[t]  # shape {batch,dim}
    with MaskedComputation(mask=(align_label.get() != BLANK)):
        slow = slow_rnn(align_label.get(), x_t)
    fast = fast_rnn(align_label.get(), x_t, slow)
    blank_pred_energy = blank_pred(fast)
    log_prob_blank = log_sigmoid(blank_pred_energy)
    log_prob_not_blank = log_sigmoid(-blank_pred_energy)
    log_prob_non_blank_labels = log_softmax(non_blank_pred(fast))
    log_prob_combined = concat(log_prob_non_blank_labels + log_prob_not_blank, log_prob_blank)
    align_label.assign(choice(log_prob_combined, input_type="log_prob"))
    loop.end(t >= x.seq_len)
```
This is not so much about `MaskedComputation` here (see #23 about this) but esp. about the question whether there is anything special w.r.t. the `Loop` concept, or whether this would just work as-is. Maybe that is already the case.
One aspect we should define more clearly is the behavior of `State.assign`. This currently depends on the context.
Resolved: No.
So far this is clear. However, how about masked computation (#23) or conditional code (#24)? This probably should be disallowed.
Edit: Just disallowing this is not really an option. Any custom module could define its own hidden state and this would not work then. E.g. in the example above, `slow_rnn` definitely would have hidden state. Imagine that `SlowRNN` is like the `Lstm` defined before.
Edit: I think actually the wanted should-be behavior is clear in all cases. Just translate the conditional code with `if`/`else`. And the masked computation is kind of like `if mask:` where the `else:` branch would just repeat the previous value.
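That semantics can be sketched in plain Python (a hypothetical toy; the accumulating update is made up for illustration): when the mask is false, the state simply keeps its previous value.

```python
def masked_updates(xs, masks, initial=0):
    h = initial
    out = []
    for x, m in zip(xs, masks):
        if m:
            h = h + x  # the masked computation runs the update...
        # ...else: the state repeats (keeps) its previous value
        out.append(h)
    return out

print(masked_updates([1, 2, 3, 4], [True, False, True, False]))  # [1, 1, 4, 4]
```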
Edit: On the behavior of `State` and `State.assign`: this also goes along with the behavior of `State.get`. `State.assign` itself should (must) be allowed everywhere. The sequence of `get` and `assign` calls and their corresponding contexts matters.
Edit: So, generalizing and specifying a bit:
`State.assign` overwrites the previous assign when done in the same context and there was no `read` (`get`) in between.
Although maybe we don't really need to care about overwrites?
Contexts are hierarchical, and ordered. So far we have `Loop` (this proposal here), `Cond` (#24) and `MaskedComputation` (#23). Maybe more later? In TF, as far as I know, there is only `tf.while_loop` and `tf.cond`.
There is an implicit first `assign` via the initial value. And also this `assign` is in the root context. So `assign` is always before `read`. I.e. for some `read`, there is always a previous `assign`, in the same or an upper context.
The first `read` in a `Loop` ctx will correspond to the last `assign` in the ctx.
Resolved: This is wrong
Edit: Actually, is there anything special about `State`? Don't we need the same logic even in general? Consider the following code:
```python
y = const(0)
with Loop():
    y = y + 1
return y
```
What would you expect? Or consider this:
```python
with Cond(...) as cond:
    y = const(1)
with cond.false_branch():
    y = const(2)
return y
```
`y` would be a `LayerRef`. Accessing it / reading it would then consider its definition and the context of its definition, and translate this to the current context.
This probably cannot work like this because we cannot really catch the `y` assignment (`y = ...`), so we also do not know what is a reassignment of the same local Python variable. (It works for object attributes, e.g. like `net.y = ...`, but not sure if this would be so clear...)
So to answer the question: yes, `State` is special. There are different ways to design this, though. Maybe `State` is also a bad name, and it should be `StateVar` or just `Var` or so.
`State` can derive from `LayerRef`, though. The `read` (or `get`) basically corresponds to `get_name`. Maybe this API should be refactored (renamed) a bit. So we can catch that. And we would add `assign`.
Edit: Back to the behavior definition: The first `read` (or `get_name`) inside a `Loop` ctx would result in `"prev:..."` iff there is an `assign` later in the `Loop` ctx.
This implies that the layer dict cannot be created right away, because we have to wait until the `Loop` ctx is finished to know whether there is an `assign` later.
Or we just always make it output `prev:...`? No... We should not make it a recurrent var when not needed.
Maybe the user could make it explicit. Like `State.set_stateful()` inside the `Loop` ctx before the first `read` (`get_name`). Then it would result in `"prev:..."`. And we assert an `assign` sometime later. Or if this is not set, we would error on any assign inside the loop. This would allow that the layer dict can be created right away, and would probably be simpler.
Resolved: This can maybe also be done automatically on all `State` members (attribs) of a `Module` when the module is called inside the loop. Because we assume that a `State` member of a module would always be assigned within the module call (otherwise it doesn't really make sense that the module defines this as a `State`).
There is still the question of behavior when there are multiple `assign`s in the loop. The last `assign` defines the value for the first `read`. But how to know which is the last before executing all? Or again let the user make this explicit? Is this always doable? Or maybe allow only a single `assign` inside the loop, and assert that it comes after the `read`.
However, what if the module is called multiple times? This would result in multiple `assign` calls. Actually the full `set_stateful`, `read` and `assign` call sequence, maybe within the same ctx. Disallow this? What would be the work-around for the user? Or must we allow this? But then this means that we can only construct the layer dict of the loop and everything in it after the loop construction is finished.
Then we maybe also do not need `set_stateful`.
Then instead of `get_name`, we must use some different mechanism, because some explicit `read`/`get` call is needed to store the context.
Resolved: Is this `State` only about `Loop`? When `set_stateful` is called inside a `Cond`, what would that mean? This must be allowed, in any ctx. But is the usage pattern always the same? It would be (in the same ctx) `set_stateful`, `read`, `assign`? We require `read` before `assign`, and `assign` to be in the same ctx as `set_stateful`. We allow `read` at any other place as well. We also allow `assign` before `set_stateful` in an outer ctx, which would define/overwrite the initial state.
Yes, State
is only about Loop
(even though it also might require further logic or Cond
etc), and specifically to handle the "prev:..."
logic of RETURNN. For other purpose (e.g. returning something from Cond
), there would be other separate ways.
Ok, posting as a new cleaned up comment on `State`, summarizing from the previous discussion/thoughts.

`State` is only about `Loop` (even though it also might require further logic for `Cond` etc.), and specifically to handle the `"prev:..."` logic of RETURNN. For other purposes (e.g. returning something from `Cond`, see #24), there would be other separate ways.

Contexts are hierarchical, and ordered. So far we have `Loop` (this proposal here), `Cond` (#24) and `MaskedComputation` (#23). Maybe more later? In TF, as far as I know, there is only `tf.while_loop` and `tf.cond`.
`State` would not derive from `LayerRef`, but the `ILayerMaker` would know how to handle `State` (so the user does not need to call `State.get` explicitly).

`State` methods:

- `read` or `get`
- `assign`

`assign` and `read` are allowed in all contexts (`Loop`, nested `Loop`, `Cond`, etc.).
However, not all possible sequences of calls (mixing different contexts) are allowed.
Behavior of `State.assign` currently depends on the context.

`State.assign` overwrites the previous assign when done in the same context and there was no `read` (`get`) in between. Although maybe we don't really need to care about overwrites?

There is an implicit first `assign` via the initial value, and this `assign` is in the root context. So `assign` is always before `read`. I.e. for some `read`, there is always a previous `assign`, in the same or an upper context.

The first `read` in a `Loop` ctx will correspond to the last `assign` in the ctx.
The first `read` (or `get_name`) inside a `Loop` ctx would result in `"prev:..."` iff there is an `assign` later in the `Loop` ctx. This implies that the layer dict can not be created right away, because we have to wait until the `Loop` ctx is finished to know whether there is an `assign` later.

It must be possible to call a module multiple times. This would result in multiple `assign` calls. Actually the full (`set_stateful`), `read` and `assign` call sequence, maybe within the same ctx. This means that we can only construct the layer dict of the loop and everything in it after the loop construction is finished.

This implies that we do not need `set_stateful`.

Then instead of `get_name`, we must use some different mechanism, because some explicit `read`/`get` call is needed to store the context.
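To make the deferred-resolution idea above concrete, here is a tiny pure-Python sketch (not the real returnn-common API; `ToyState` and its methods are made up for illustration): reads are recorded as placeholders and only resolved to `"prev:..."` once the whole loop body is known, which is exactly why the layer dict can only be constructed after the loop construction is finished.

```python
class ToyState:
    """Toy state var: records reads/assigns, resolves reads after the loop."""
    def __init__(self, name):
        self.name = name
        self.events = []  # ordered ("read", ref) / ("assign", value) events

    def read(self):
        ref = {"resolved": None}  # placeholder, filled in by finish_loop()
        self.events.append(("read", ref))
        return ref

    def assign(self, value):
        self.events.append(("assign", value))

    def finish_loop(self):
        # A read resolves to "prev:<name>" iff an assign follows it in the loop.
        for i, (kind, ref) in enumerate(self.events):
            if kind != "read":
                continue
            later_assign = any(k == "assign" for k, _ in self.events[i + 1:])
            ref["resolved"] = ("prev:" + self.name) if later_assign else self.name

h = ToyState("h")
r1 = h.read()       # read before the assign -> becomes "prev:h"
h.assign("new_h")
r2 = h.read()       # read after the last assign -> becomes "h"
h.finish_loop()
```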
I am just now reading into the updates from the last few days, so maybe I missed it: Is there a concept for setting a target of the rec unit yet?
Now looking at the LSTM example, if I understand correctly `self.c.prev()` would result in referencing `c`. So pretty much what we have is a variable, most likely called the same as the layer itself. Side question: are there any cases where this same naming is unlikely/harmful? Otherwise, it kind of feels like we introduce a new variable `self.c = State({batch,dim})` and 2 parts of code, `self.c.prev()` and `self.c.assign(c)`, for a statement which from reading the code semantically just means: we want to access the previous state of this layer, and would in the end just write `prev:c` in the config. Of course I realized that otherwise accessing the `.prev()` before the layer itself might be difficult. Right now I cannot think of a simpler direct way, but I wanted to bring this up for discussion again since this feels a bit much to me, and being as straightforward as possible is one of the core goals of the project.

In general: do we really need this complicated assign logic? Shouldn't it just have two values: the initial one and then the corresponding layer? Because I don't think that the prev value (for writing the config) should change once it is set for one layer. Or am I misunderstanding here?
> Is there a concept for setting a target of the rec unit yet?
I'm not sure I understand. What do you mean or intend? Define a loss inside the rec layer? Why is there anything special about it? This would follow just the same logic as for other losses (see e.g. #9). You would define CE loss based on `data:classes`. It roughly would look like this:
```python
x = ...
targets = get_extern_data("classes")
with Loop() as loop:
    ...
    prob = ...
    if train:
        targets_ = loop.unstack(targets)
        loss = ce(targets_, prob)
        loss.mark_as_loss()
```
You would not use the `target` option of RETURNN layers.
(Maybe this example should also be moved to the initial post above.)
> Now looking at the LSTM example, if I understand correctly `self.c.prev()` would result in referencing `c`. So pretty much what we have is a variable, most likely called the same as the layer itself.
Yes, a local variable in the Python sense.
Note that `prev` was changed to `get` or `read` in some of the later discussion. I just forgot to update the initial example. I think `get` makes it even more clear that this behaves logically like a variable. The `"prev:..."` logic in RETURNN is hidden here and would be applied automatically when needed (which is only when this is used inside a Rec loop, for the first `get`).
About the naming, I'm not exactly sure. Maybe. Probably. But not so important. I'm making #25 a priority here.
> Side question: are there any cases where this same naming is unlikely/harmful? Otherwise, it kind of feels like we introduce a new variable `self.c = State({batch,dim})` and 2 parts of code, `self.c.prev()` and `self.c.assign(c)`, for a statement which from reading the code semantically just means: we want to access the previous state of this layer, and would in the end just write `prev:c` in the config. Of course I realized that otherwise accessing the `.prev()` before the layer itself might be difficult. Right now I cannot think of a simpler direct way, but I wanted to bring this up for discussion again since this feels a bit much to me, and being as straightforward as possible is one of the core goals of the project.
I don't quite understand what you mean. Are you arguing about the naming? Or what do you think is too complicated? To me this looks very simple and straight-forward now. Compare the very first example in the initial post. The standard Python code and our rec loop code look conceptually very similar, and both just equally simple.
> In general: do we really need this complicated assign logic? Shouldn't it just have two values: the initial one and then the corresponding layer? Because I don't think that the prev value (for writing the config) should change once it is set for one layer. Or am I misunderstanding here?
As said, it behaves logically like a Python variable. You can assign a Python variable multiple times. `x = 1; ...; x = 2; ...` is valid code.
You could say, ok we disallow multiple assigns to keep it simple. Only one single assign.
But then when the user writes e.g. this in the config:
```python
lstm = Lstm(...)
with Loop() as loop:
    layer1 = lstm(loop.unstack(x))
    layer2 = lstm(layer1)
    output = loop.stack(layer2)
```
This would not work, because the second call to `lstm` would assign the same state var a second time.

Or you somehow introduce a separate/independent state var for each `lstm` call here. But then it would also not behave as you would naturally expect.

So, to make behavior correct and as expected, you need to allow multiple assign calls.

This is the goal: to have the behavior straight-forward and as expected, with a simple mental model.
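A plain-Python analogy of why the two calls share state: if the module object itself carries the hidden state, calling it twice silently chains the state through both calls, exactly like a mutable attribute on a Python object. The `ToyAccumCell` class below is purely illustrative (not returnn-common code):

```python
class ToyAccumCell:
    """Toy stand-in for a module with hidden state."""
    def __init__(self):
        self.h = 0  # hidden state, implicitly shared across calls

    def __call__(self, x):
        self.h += x  # the call assigns the shared state (a side effect)
        return self.h

cell = ToyAccumCell()
layer1 = cell(1)  # h: 0 -> 1
layer2 = cell(2)  # continues from layer1's state: h: 1 -> 3
```

The second call does not start from a fresh state; it continues from wherever the first call left it, which is the behavior under discussion here.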
Note that I'm mostly working now on the generalized self attention on RETURNN side (https://github.com/rwth-i6/returnn/issues/391) and related (required) things like extended dimension tag support (e.g. https://github.com/rwth-i6/returnn/pull/577).
I want to have this generalized self attention in RETURNN ready to see if this also fits nicely into this rec loop concept or if we should adapt or change sth. The generalized self attention in RETURNN might introduce some new tricky things like a special dim tag for inside the loop (rec-history dim tag).
The extended dimension tag support in general in RETURNN will also become relevant here because in returnn-common, I want to put much more emphasis on this, and maybe even disallow (or deprecate) any other way. See #17.
> What do you mean or intend?
When not setting `target` for the Rec unit in the current config, this causes an error with the end layer during training, that batch shapes `(None,)` and `(None, 1)` do not match.
> I don't quite understand what you mean. Are you arguing about the naming? Or what do you think is too complicated?
I was thinking each `prev:` in the RETURNN config introduces at least 3 new lines of code, which in total would sum up quite quickly. But now I understand that you want this State to be more abstract and more powerful than just tracking the `prev` case; in that case these 3 lines are not much, I agree with you. Most likely my problem right now is that I don't really see how this would translate to a RETURNN config. I can see that you try to do more with State here, but right now I am not sure what for, which if I understand you correctly is the current goal. Can you maybe give an example where we would want to do more with the `State`?
> This would not work because the second call to lstm would assign the same state var a second time. Or you somehow introduce a separate/independent state var for each lstm call here. But then it would also not behave as you would naturally expect. So, to make behavior correct and as expected, you need to allow multiple assign calls.
I am not sure whether I understand this example correctly. Right here we don't have a state variable, so why would we assign it twice if we only call the assign command once? Or do you mean something like (assuming the state var is `self.state`) where we do `self.state.assign(layer1)` and then `self.state.assign(layer2)`? In that case I would argue that reusing the same state for two different layers might also make things complicated. Most likely I am missing something here.
> > What do you mean or intend?
>
> When not setting `target` for the Rec unit in the current config, this causes an error with the end layer during training, that batch shapes `(None,)` and `(None, 1)` do not match.
I don't really understand. It sounds like a bug. But anyway, things like this are really irrelevant for the discussion here. We should just fix or extend RETURNN in any way needed.
> I was thinking each `prev:` in the RETURNN config introduces at least 3 new lines of code,

Exactly like in the pure Python code in the beginning. Which is extremely simple?
> which in total would sum up quite quickly.

How? The user would probably rarely write such code. Most `State`s would be introduced by some existing `Module`s, and the user would just use those `Module`s.
> Most likely my problem right now is, I don't really see how this would translate to a RETURNN config.
Well, speaking purely abstractly: RETURNN is generic enough to allow just any possible construction, or if not, it should be, and should be extended to allow that.
So this should never be a concern here in this repo (how it translates to a RETURNN config). The main goal is to have it logical, simple, straight-forward, expectable. How it translates to RETURNN is an implementation detail, after we have clarified what design we want.
Speaking more specifically now, how we actually do the implementation: I think most aspects were already discussed here (the thread gets too long to easily see that...). There might be some open questions but I think they can be solved easily.
I will start working on this once I finished the generalized self attention logic in RETURNN.
> I can see that you try to do more with State here, but right now I am not sure what for, which if I understand you correctly is the current goal. Can you maybe give an example where we would want to do more with the `State`?
I just gave you one? This is one such example:
```python
lstm = Lstm(...)
with Loop() as loop:
    layer1 = lstm(loop.unstack(x))
    layer2 = lstm(layer1)
    output = loop.stack(layer2)
```
> I am not sure whether I understand this example correctly. Right here we don't have a state variable, so why would we assign it twice if we only call the assign command once?
We have. The `Lstm` introduces it. See the `Lstm` definition above. E.g. there is then `lstm.c`. And the `lstm(...)` call results in `lstm.c.assign`.

This has multiple state assign calls because you call the same module multiple times, which reuses the same hidden state.
Although, maybe this is actually not such a good example. I wonder now. I'm thinking about e.g. Universal Transformer now, where you also would call the same self attention layer multiple times. But there you do not want to have it use the same hidden state; every call (layer) should still have its own hidden state, just the parameters should be shared.

So maybe the idea that a module can have hidden state is not good after all? Again, speaking of PyTorch: PyTorch modules can have that as well (as buffers), but usually it is never done this way. E.g. the `LSTM` module or the `LSTMCell` module in PyTorch explicitly gets the prev state and cell as arguments and returns the new state and cell.
Maybe we should somehow make it more explicit, what hidden state is being used? Maybe like this:
```python
lstm = Lstm(...)
with Loop() as loop:
    with StateScope():
        layer1 = lstm(loop.unstack(x))
    with StateScope():
        layer2 = lstm(layer1)
    output = loop.stack(layer2)
```
Maybe we also should differentiate between buffers and state? Buffers (in general, and also in PyTorch) are intended to not really have influence on the behavior, or not on the model behavior at least. I'm actually not so sure on the typical usage of buffers in PyTorch. But they are definitely not used as hidden state. Hidden state is always passed explicitly.
> See the Lstm definition above.
Okay, no clue how I missed that this references the definition of LSTM you gave here; I handled it like a black box... now things make a lot more sense. That's, I think, what caused most of the confusion.
> So this should never be a concern here in this repo (how it translates to a RETURNN config). The main goal is to have it logical, simple, straight-forward, expectable. How it translates to RETURNN is an implementation detail, after we have clarified what design we want.
I see, then I think I do get it now. So (to verify I am not mistaken) pretty much the state is a concept of returnn-common which then in some cases might translate to `prev:` in RETURNN, but is a lot more powerful and could do a lot more than just that. `prev:` is only one of the possible cases State can be used in. Using this definition I agree with you, these lines of code are simple and straightforward.
> Maybe we should somehow make it more explicit, what hidden state is being used?
I don't know. If we really want to end up with a module-like structure, I feel like hidden states should be something the user should not really need to deal with himself if he is not changing the basic configuration of that module (which shouldn't happen too often, because users usually should put their model together from modules without needing to make too many new ones imo). Adding something like `with StateScope()` would, if not needed for some logic/expressiveness, make it more confusing I think.
> So (to verify I am not mistaken) pretty much the state is a concept of returnn-common which then in some cases might translate to `prev:` in RETURNN, but is a lot more powerful and could do a lot more than just that. `prev:` is only one of the possible cases State can be used in. Using this definition I agree with you, these lines of code are simple and straightforward.
Yes. Actually, in most of the common cases here in returnn-common it would result in `"prev:..."`. I'm more thinking about the maybe somewhat unusual cases.
> > Maybe we should somehow make it more explicit, what hidden state is being used?
>
> I don't know. If we really want to end up with a module-like structure, I feel like hidden states should be something the user should not really need to deal with himself if he is not changing the basic configuration of that module (which shouldn't happen too often, because users usually should put their model together from modules without needing to make too many new ones imo). Adding something like `with StateScope()` would, if not needed for some logic/expressiveness, make it more confusing I think.
There are multiple aspects/concepts:

- `Module` class.
- `Module` object (instance of the class). A module object can have parameters, and sub modules. We propose here that it can also have state.
- `Module` object call. This performs the computation by using the model parameters. Calling the same module object again will reuse the same model parameters. We propose that it also uses the same state. Although that is what I question now a bit.

So the model parameters are shared in this example:
```python
lstm = Lstm(...)
with Loop() as loop:
    layer1 = lstm(loop.unstack(x))
    layer2 = lstm(layer1)
    output = loop.stack(layer2)
```
In the current proposal, the hidden state is also shared in this example.

However, maybe the more common use case is that the user wants to share the parameters but not the hidden state. Again, I'm thinking about Universal Transformer. But also in most other examples I can think of where parameters are to be shared, you would not want to share the hidden state.
I think it should be simple to write code for the common cases, while also allowing for exotic things, while also being straight-forward/logical.
How would you actually implement this common case, where you want to share parameters but not the hidden state? Maybe like this:
```python
lstm = Lstm(...)
with Loop() as loop:
    # Introduce separate state vars such that each layer has own hidden states.
    # Copy state vars now to have the right initial state.
    h1, c1 = lstm.h.copy(), lstm.c.copy()
    h2, c2 = lstm.h.copy(), lstm.c.copy()
    lstm.h, lstm.c = h1, c1  # assign own hidden state vars
    layer1 = lstm(loop.unstack(x))
    lstm.h, lstm.c = h2, c2  # assign own hidden state vars
    layer2 = lstm(layer1)
    output = loop.stack(layer2)
```
I'm not sure if this is still so easy. Or I would rather say no.
Note, in PyTorch, this (sharing parameters but not hidden state) would look like:

```python
lstm = LstmCell(...)
layer1_state = lstm.initial_state()
layer2_state = lstm.initial_state()
outputs = []
for x_ in x:
    layer1, layer1_state = lstm(x_, layer1_state)
    layer2, layer2_state = lstm(layer1, layer2_state)
    outputs.append(layer2)
output = torch.stack(outputs)
```
In PyTorch, sharing parameters and hidden state would look like:

```python
lstm = LstmCell(...)
lstm_state = lstm.initial_state()
outputs = []
for x_ in x:
    layer1, lstm_state = lstm(x_, lstm_state)
    layer2, lstm_state = lstm(layer1, lstm_state)
    outputs.append(layer2)
output = torch.stack(outputs)
```
In both cases, you make the handling of state explicit. So there is no confusion on the behavior of state, because it is always explicit.
So, I'm wondering if we also should make it always explicit, but still in a generic way. That is why I proposed `StateScope`. Maybe `StateScope` should be passed to the module call via `state=StateScope()` instead of using `with` here. Or maybe call it `StateHolder` or so, which can hold any nested structure of state vars. The idea with `with` was that it would be nested and automatically used for any nested sub module calls. I agree that the initial example code for `StateScope` above is not so clear. This should be worked out.

This is all about parameter sharing, i.e. calling the same module multiple times, and how the hidden state should be handled in this case. Because we hide the hidden state away in the current proposal, such a module call has side effects on internal state. Generally speaking, side effects are bad, because they make it more difficult to reason about the behavior of code. Usually you want that some call like `lstm(...)` does not have side effects.
Btw, this argument is also about module calls in general. Not really about rec loop actually. E.g. this code in the current proposal:
```python
lstm = LstmOnSeq(...)
layer1 = lstm(x)
layer2 = lstm(layer1)
output = layer2
```
It again would share the parameters (as expected) but also the hidden state (not sure if that is expected, although it follows logically from the current proposal).
In PyTorch, the example of sharing the parameters but not the hidden state would look like:

```python
lstm = Lstm(...)
layer1, _ = lstm(x)       # default initial hidden state
layer2, _ = lstm(layer1)  # default initial hidden state
output = layer2
```
Ok, following from those thoughts, I'm now thinking that we definitely should make it explicit. For all modules which have state (by themselves, or via sub modules). And state is not an attrib of the module but just an explicit parameter and return value of a module call.
There can be a reasonable default initial state. E.g. when you have `def forward(self, ..., state=None)` in a module, you can check `if state is None: state = some_default_initial_state()` or so. Every such module would return the new state, e.g. as a tuple as in the PyTorch LSTM example.
The example for sharing parameters but not sharing hidden state would look sth like:
```python
lstm = Lstm(...)
layer1_state = StateScope()
layer2_state = StateScope()
with Loop() as loop:
    layer1, layer1_state = lstm(loop.unstack(x), layer1_state)
    layer2, layer2_state = lstm(layer1, layer2_state)
    output = loop.stack(layer2)
```
Which is very similar to the PyTorch example.
From the point of view of the model implementation, it is a bit strange now that there is a conceptual difference between arguments which are state (and thus not normal `LayerRef` instances but `State` or `StateScope`). And maybe sometimes you want that one argument becomes a state, and maybe at other times a different argument becomes a state, and maybe sometimes you do not even want that any of the arguments are states, e.g. calling `lstm(x, (prev_h, prev_c))` where `prev_h` and `prev_c` are just other normal `LayerRef`s.

So this is a problem. I think the model definition should handle it all the same and just always expect `LayerRef`s (or nested structures of `LayerRef`s).
But then how do we handle this? Maybe more explicitly:

```python
lstm = Lstm(...)
layer1_state = State(lstm.initial_state())
layer2_state = State(lstm.initial_state())
with Loop() as loop:
    layer1, layer1_state_ = lstm(loop.unstack(x), layer1_state.get())
    layer1_state.assign(layer1_state_)
    layer2, layer2_state_ = lstm(layer1, layer2_state.get())
    layer2_state.assign(layer2_state_)
    output = loop.stack(layer2)
```
`State` here would be extended a bit so that it can also handle nested structures. `lstm.initial_state` would still return a normal `LayerRef`.
I think this would be very logical and straight-forward now again. And all is explicit and behavior is always clean.
Other modules (e.g. `Lstm` etc.) would never deal with the `State` concept. The only point where you deal with that is when you write a rec loop, i.e. `Loop()`. Which is probably also encapsulated away in some module, like `Decoder`, and the user in most common cases probably just writes sth like:

```python
encoder = Encoder(...)
decoder = Decoder(...)
output = decoder(encoder(...))
```
I wonder a bit if this is too complicated now, to write the `Loop` and explicitly handle the state this way. But I'm not really sure how to make this simpler.
Or maybe like:
```python
lstm = Lstm(...)
loop_state = State()
loop_state.layer1 = lstm.initial_state()
loop_state.layer2 = lstm.initial_state()
with Loop() as loop:
    layer1, loop_state.layer1 = lstm(loop.unstack(x), loop_state.layer1)
    layer2, loop_state.layer2 = lstm(layer1, loop_state.layer2)
    output = loop.stack(layer2)
```
This would add a bit of clever handling into the `State` object. Basically the `assign` and `get` calls here are covered by `__setattr__` and `__getattr__`. This allows writing the loop code a bit shorter and maybe more readable.
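A rough sketch of how such a state object could cover `assign`/`get` via `__setattr__`/`__getattr__`. This is toy pure-Python code; the names `StateVar` and `StateHolder` are illustrative, not the real returnn-common API:

```python
class StateVar:
    """Toy state var with explicit get/assign."""
    def __init__(self, value):
        self.value = value

    def get(self):
        return self.value

    def assign(self, value):
        self.value = value

class StateHolder:
    """Attribute writes become .assign() calls, reads become .get() calls."""
    def __init__(self):
        object.__setattr__(self, "_vars", {})

    def __setattr__(self, name, value):
        vars_ = object.__getattribute__(self, "_vars")
        if name in vars_:
            vars_[name].assign(value)  # later writes update the existing var
        else:
            vars_[name] = StateVar(value)  # first write creates the state var

    def __getattr__(self, name):
        vars_ = object.__getattribute__(self, "_vars")
        if name not in vars_:
            raise AttributeError(name)
        return vars_[name].get()

loop_state = StateHolder()
loop_state.layer1 = 0  # creates the state var (initial value)
loop_state.layer1 = loop_state.layer1 + 5  # read via get(), write via assign()
```

So `loop_state.layer1 = lstm(...)`-style code as in the example above could transparently route through the state var, without the user spelling out `get`/`assign`.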
I was quite busy with all the recent dim tag work (https://github.com/rwth-i6/returnn/pull/577) and generalized self attention (https://github.com/rwth-i6/returnn/issues/391). Generalized self attention is finished now, and the concept of consistent dim tags has been improved and extended a lot. Although there are still on-going discussions on taking this even further (https://github.com/rwth-i6/returnn/issues/597, https://github.com/rwth-i6/returnn/issues/632). Some of these might also be relevant for returnn-common but we should discuss this independently, maybe in #17.
I lost a bit track on the current state here, and the reasoning.
I think the final conclusion on the hidden state was to have it always explicit as a (`state`) argument to the module call, and also to explicitly return the new state. So very much like PyTorch. So it is not really hidden at all.

I like it that way because there is no ambiguity in the code, how the hidden state is handled. This is all explicit.

The initial post description is maybe outdated on this. Do we still need to have the state as an attribute of the module (e.g. `self.h = State({batch,dim})`)? Why?
And why do we need this special `State` object with `assign` and `get` calls? Or the simplified code which uses `__setattr__` and `__getattr__` on some special `State` object?

I can see that the `state` module call argument might not just be a single `LayerRef` but could also be some nested structure, esp. if this is a module with other sub modules which also could have state. But this does not really explain why we need `State`. Maybe the concept of the `State` object was just introduced for the initial idea where we did not want to make this explicit, where the hidden state would have been all implicit and hidden? And now it is not needed anymore?
**Edit** The last comment just before this actually says:

> state is not an attrib of the module but just an explicit parameter and return value of a module call.

So as I argued above, we would not have this module attrib anymore (like `self.h = State({batch,dim})`).

States are always explicitly passed to a module call, and new states are returned from it.
However, I don't understand this anymore:
> From the point of view of the model implementation, it is a bit strange now that there is a conceptual difference between arguments which are state (and thus not normal `LayerRef` instances but `State` or `StateScope`)
Why is there a conceptual difference? Does there need to be? Why? Why can't the `state` module call argument just be a regular `LayerRef` (or nested structure)?
Actually I also addressed this before:
> I think the model definition should handle it all the same and just always expect `LayerRef`s (or nested structures of `LayerRef`s).
But still in this comment I keep `State` (or `StateScope`) as a special concept (e.g. `loop_state = State()`). Why is this needed?
Edit
The first example from above would look like:
```python
lstm = Lstm(...)
loop_state_layer = lstm.initial_state()
with Loop() as loop:
    layer1, loop_state_layer = lstm(loop.unstack(x), loop_state_layer)
    layer2, loop_state_layer = lstm(layer1, loop_state_layer)
    output = loop.stack(layer2)
```
The second example from above would look like:
```python
lstm = Lstm(...)
loop_state_layer1 = lstm.initial_state()
loop_state_layer2 = lstm.initial_state()
with Loop() as loop:
    layer1, loop_state_layer1 = lstm(loop.unstack(x), loop_state_layer1)
    layer2, loop_state_layer2 = lstm(layer1, loop_state_layer2)
    output = loop.stack(layer2)
```
The problem is that this does not exactly correspond to the Python `while ...:` loop as we want it to w.r.t. the Python local variables. We cannot correctly infer from this Python code that `loop_state_layer` is a recurrent variable which changes inside the loop. So this is probably one reason for the `StateScope` as it was suggested.

But the same problem actually exists for any other output, i.e. anything in the loop which wants to use the value from the previous iteration. This was not really addressed before? Is this not a problem? Or was this solved differently?

**Edit** Ah, this is actually also in the very first proposal. That is what `State` really is for: to handle the RETURNN `prev:` logic in a generic way. So basically this solves #6.
So, to recap:

- Recurrent variables (`prev:...` in RETURNN, #6) and hidden state would be handled in the same way.
- `State` (or `StateScope`, or however we call it) is only relevant for the code which directly operates with `Loop()`.
- The `state` argument in the module call can be any arbitrary nested structure.

I still see a problem in catching errors here. When the user writes the code ignoring `State`, just like before, it would compile without error, but it would do the wrong thing, i.e. not the expected behavior. It would always use `lstm.initial_state()` as state in every iteration in this example.
Can we somehow catch such errors to avoid unexpected behavior?
Or is this maybe not too much a problem as this is actually not too much unexpected?
Also, do we want to be able to skip the `state` argument, i.e. that it has a reasonable default? Modules might have `state=None` in the function signature and then internally do sth like `if state is None: state = self.initial_state()`. However, this code would have exactly the problem just described, i.e. it would not use the prev state but always the initial state in every iteration. Is this a fundamental problem which cannot really be solved?
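The pitfall can be illustrated in plain Python (toy code; `toy_cell` is a made-up stand-in for such a module with a `state=None` default):

```python
def toy_cell(x, state=None):
    """Toy recurrent cell: accumulates x into its state."""
    if state is None:
        state = 0  # default initial state
    new_state = state + x
    return new_state, new_state  # (output, new state)

xs = [1, 2, 3]

# Buggy: the returned state is dropped, so every frame
# silently restarts from the initial state.
buggy_outs = [toy_cell(x)[0] for x in xs]  # -> [1, 2, 3]

# Correct: the state is threaded explicitly through the loop.
state = None
correct_outs = []
for x in xs:
    out, state = toy_cell(x, state)
    correct_outs.append(out)  # -> [1, 3, 6]
```

Both versions run without error; only the second actually implements the recurrence, which is exactly the kind of silent wrong behavior discussed above.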
In PyTorch, this is the same behavior though, right? In PyTorch, there is the difference of `LSTM` (on a seq) vs `LSTMCell` (on a single frame). `LSTM` does have this default initial state, but `LSTMCell` does not, as it does not make sense for this case. In RETURNN, we have both together, which maybe causes this confusion. But we do not need to wrap it exactly the same here in returnn-common. We could also have some `LSTMCell`. Or maybe use `RnnCell` or `RecCell` to be able to use other rec units from RETURNN as well (not just `NativeLstm2`). Or wrap them all separately. Or both. And for all of these, we require the state to be explicit as an argument (no default). Although such modules would still have some function `initial_state`.
I'm questioning now whether this explicit code is maybe too complicated for many of the common use cases. On the RETURNN side, the hidden state is hidden and not explicit. So when translating some old-style RETURNN model to this new way, it would take somewhat extra effort (although it should be straight-forward).
One alternative is to introduce `State` inside the module call (but not as a module attrib). So an `LSTMCell` or `RecCell` could be defined like:
```python
class LstmCell(Module):
    def __init__(...): ...

    def forward(self, input, state=None):
        if state is None:
            state = StateScope(...)
        h, c = state.h, state.c
        ...
        h.assign(new_h)
        c.assign(...)
        return new_h
```
Then we can use e.g. this code:
```python
lstm = LstmCell(...)
with Loop() as loop:
    layer1 = lstm(loop.unstack(x))
    layer2 = lstm(layer1)
    output = loop.stack(layer2)
```
This would do param sharing between both `lstm` calls. However, it would not share the same hidden state (as expected).
The `StateScope` would be attached to the `Loop` (via sth like `Loop.get_current() -> Loop`, which we can implement via the `with loop:` logic).
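The `Loop.get_current()` idea can be sketched with a class-level stack maintained by the `with` statement (toy code, not the real `Loop` implementation; `ToyLoop` is a made-up name):

```python
class ToyLoop:
    """Toy loop context: tracks the currently active loop via a stack."""
    _stack = []

    def __enter__(self):
        ToyLoop._stack.append(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        assert ToyLoop._stack.pop() is self

    @classmethod
    def get_current(cls):
        return cls._stack[-1]  # innermost active loop

with ToyLoop() as loop:
    # Code inside the block (e.g. a module call) can find the current loop
    # without it being passed explicitly.
    assert ToyLoop.get_current() is loop
```

With this, a module call inside the `with` block can attach its `StateScope` to the current loop without the loop being passed around explicitly; nesting works because the stack always has the innermost loop on top.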
The not-so-nice thing about this is that we clearly differentiate between state and other input now. So it becomes complicated/non-straightforward when the user would also want to pass some custom state (custom `h` or `c`) to `lstm`. Or how to return the new cell state `c`.
It would maybe look like this:
```python
layer1_c = get_state_scope(layer1).c.get()
```
So I just read through anything and I will add my comments. Since it was quite a lot, maybe I also missunderstood something, then just correct me:
So first lets start with the concept of a State
. From what I remeber we started with adding a logic which is able to handle the prev:
logic from Returnn and then started expanding on that adding more "features" like hidden state handling to it.
One of my general questions would be, why we even do explicit recurrent handling like getting state updates and so on instead of just "references" to these updates which would be passed in to Returnn for the config. Isn't the actual handling of how the LSTM unit works something the is part of Returnn and what we try to achieve in this repo is a nice way of writing it down. From what I understand right now you are also looking to include additional concepts.
From this I would also conclude my reasoning for deciding between the two variants for the hidden state: I feel like the implicit handling is one of the biggest strengths of Returnn, even though of course it is a sharp sword to work with. Not explicitly having to worry about certain details of a layer makes it (at least from my HiWi view) much easier to start and also work with. What we should aim for in my opinion is an interface which is as simple as possible to "just start", but has the flexibility of allowing stronger configurations once the user is more used to it. So I feel like it would be fine to accept that basic configurations (where in this case you don't do modifications to the internal hidden state logic) are as easy as possible, but if you want to make use of some stronger concepts you would have to go a bit more in depth. The reason why this causes trouble in other cases is when the documentation is not good enough, making users who try to transition into more detail feel lost.
Now onto the specific example: I would prefer the second option in the general case. For the more advanced options I would then include an option to get the StateScope
or certain elements of it and also include the possibility to make modifications to it. Again I think this is more of an advanced concept which I am not sure how much it will be used. Maybe I am mistaken here. So what we could allow is doing something like:
scope = get_state_scope(layer1)
and then do stuff like scope.c.get()
like you suggested, but also maybe scope.c.set()
or even set_state_scope(layer1, scope)
to overwrite the full scope with the (modified) new scope. This would leave the whole construction flexible enough to handle these cases in my opinion, without too much of a workaround.
But overall I think this is a point we need to put some thought into. Maybe we could work out 2 or 3 concrete ways (with 3-4 examples each) and then ask other users about it, because I feel like this is something where user feedback might be meaningful to make a decision. What do you think?
> So first lets start with the concept of a `State`. From what I remember we started with adding logic which is able to handle the `prev:` logic from Returnn and then started expanding on that adding more "features" like hidden state handling to it.
Basically, but not directly. All the discussion here should be seen independent from RETURNN really, and more about what would be a straightforward design (for people who do not know about RETURNN). It should not be that we adopt some strange thing from RETURNN only because that is how it is now in RETURNN.
So, when we think about loops (for ...:
or while ...:
), we need some way to access values from the previous iteration. And the question is, how to design that.
The next question is, whether we want to allow hidden state, which can be hidden, and thus is a separate concept, or whether there should not be a separate concept for hidden state, and it would just be the same as other values from previous iteration.
The cleanliness and straightforwardness of the design is of highest priority here. How this maps to RETURNN in the end is only secondary. We can surely map whatever we came up with, as long as it is well defined. Or if not, we can simply extend RETURNN such that we can. Although for almost everything discussed here, I think that RETURNN already supports it, so no extension or modification on RETURNN side would be needed.
> One of my general questions would be, why we even do explicit recurrent handling like getting state updates and so on instead of just "references" to these updates which would be passed in to Returnn for the config.
I don't exactly understand. What do you mean by references to the updates?
The argument of explicit vs hidden/implicit is simple: Because only that way, it is straightforward. Hidden/implicit is always problematic. Esp when you want to change it, or have some control over it, it becomes unnatural. As long as you do not want to touch or access the hidden state, it doesn't matter. But as soon as you do, it matters. And there always will be such cases.
> Isn't the actual handling of how the LSTM unit works something that is part of Returnn and what we try to achieve in this repo is a nice way of writing it down.
We are not changing that. Here we simply discuss how we design the handling of accessing previous values (values from the prev loop iteration), and hidden state, or whether hidden state should be handled differently or just the same as other previous values.
> From what I understand right now you are also looking to include additional concepts.
No. No underlying concept is really new. It would still all map to what RETURNN does right now. Just the API is new. This is the whole point here of returnn-common. And for designing the API, we have the freedom to do it as we want. And I think we should try to prioritize cleanliness and straightforwardness.
> From this I would also conclude my reasoning for deciding between the two variants for the hidden state: I feel like the implicit handling is one of the biggest strengths of Returnn, even though of course it is a sharp sword to work with. Not explicitly having to worry about certain details of a layer makes it (at least from my HiWi view) much easier to start and also work with.
Many people claim that PyTorch is easier because it is all explicit. Usually nothing is hidden away. When reading other people's code, you rarely would ask yourself what it would actually do, or whether this module has some hidden state, because it is all explicit.
Explicitness can result in slightly more code but it is usually still pretty simple and short, and it is easier to follow and reason about because you don't have to think about implicit behavior.
Implicit behavior is maybe fine for all the simple cases but once it gets more complex, it can make it really hard to reason about.
I spoke with some other people and they all strictly preferred the explicitness.
> What we should aim for in my opinion is an interface which is as simple as possible to "just start", but has the flexibility of allowing stronger configurations once the user is more used to it.
Yes, simplicity and flexibility are both the main goals of RETURNN, and also here of returnn-common.
However, I think you argue for exactly the opposite of what I argued before.
What does simple mean? Simple does not necessarily mean short code. Simple is about writing code, reading code, and understanding code. It should never be ambiguous, otherwise it is not simple. It should be clear and straightforward. Straightforwardness makes it simple to write and understand. Clearness makes it simple to read.
What does flexibility mean? It does not just mean that more complex things are possible. More complex things are always possible. Flexibility also means that more complex things are straightforward to do. Otherwise it is not really flexible, if something is not straightforward or unclear.
> So I feel like it would be fine to accept that basic configurations (where in this case you don't do modifications to the internal hidden state logic) are as easy as possible but if you want to make use of some stronger concepts you would have to go a bit more in depth. The problem why this in other cases causes trouble is when the documentation is not good enough, making users trying to make the transition into more detail feel lost.
You cannot really compensate for a complicated, non-straightforward design just by having better documentation. Treating hidden state as something different than non-hidden state just makes it more complicated, and not straightforward. When you have worked with non-hidden state before, it is not clear or straightforward how to work with hidden state now, when this is a different thing or concept.
> Now onto the specific example: I would prefer the second option in the general case.
I actually asked someone on what behavior he would expect from this code:
```python
lstm = LstmCell(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)
```
He expected that the two lstm
calls would not only share the params but also the hidden state. Which is exactly not what would happen. Or it depends on the implementation of LstmCell
. So this is a perfect example what I meant before: It is not easy to read or understand. The behavior of the hidden state is unclear and ambiguous. And it is not straightforward how to handle hidden state now.
> But overall I think this is a point where we need to put some thought into. Maybe we could work out 2 or 3 concrete ways (with 3-4 examples each) and then ask other users about it, because I feel like this is something where user feedback might be meaningful to make a decision. What do you think?
Yea, I also thought about getting some more feedback. It's a good idea to prepare some examples.
In all cases, what I think is important:
So the different examples could include e.g. a custom modification such as applying a `tanh` on the hidden state in between.
I agree, some of these examples are maybe a bit exotic. But that is my point. It should still be straightforward to do. Otherwise it is not really flexible. In PyTorch, all of these are very simple and straightforward. In current RETURNN (dict-style net def), while all are possible in principle, only the first three are simple, while the others are definitely not, esp. not straightforward. I expect and argue that whenever you have it explicit, it becomes straightforward.
Some further thoughts on the handling of state in general in a loop (orthogonal on the discussion whether hidden state should be a separate concept or not):
While assign
and get
on such a State
object are somewhat canonical, this leads to lots of boilerplate code, which makes it somewhat more complicated to write than the corresponding natural logic in normal Python (or PyTorch) code (in a normal for
or while
loop).
I'm thinking about the options to simplify that to more canonical simplified Python, while still also not doing too much magic, such that it is still clear what happens.
One approach was already proposed, which would be a StateHolder
object or so, where we do the same logic in __setattr__
and __getattr__
.
This is mostly fine, except of:

- It is more verbose (needs `loop.state.` or so as prefix).
- When the user forgets the `loop.state.` prefix, there is no error and just wrong behavior. I don't think there is a good way we could detect this as an error.

Another idea I had was to pass `locals()`
to Loop
. At the exit of Loop
, could this detect what local vars have changed inside the loop? Then this can also be used to implement such logic.
Downsides:

- Consider `b = b + 1` in the loop. So `b` gets reassigned. But now IDEs (e.g. PyCharm) and code checkers would complain that the new `b` is not used anymore. It is only used because of the locals magic.

Some variant, which solves some of the downsides, while adding again some further function:
The user could call sth like loop.exit(locals())
explicitly at the end of the loop. This is slightly less magic, more robust (should always work), and IDEs (at least PyCharm) will also not complain about unused local vars.
I played a bit around with variations of this here.
We could do some Python-level code transformation, similar as JAX, TF tf.function
(see AutoGraph transformations), PyTorch jit
, etc. This is extremely flexible and powerful and basically allows us to do it in whatever way we want. We even can simply use normal for
or while
loops. This directly allows us to write very straight-forward Python code.
Main downside: This is a heavy and complex thing to do. This adds a lot of complexity. Also, while I have some ideas how this can be implemented, and I have implemented some similar code before (on AST level, for pytorch-to-returnn), there are various different possible approaches here, and this would also need some more research, e.g. how tf.function
does this, etc.
Just out of interest, I'm following the logic of tf.function
to the autograph transformation. This looks extremely complex, with lots of edge cases. At some point, it calls autograph.converted_call
. And after a long list of exceptions and extra checks, that calls conversion.convert
. And that calls AutoGraphTranspiler.transform_function
. Then there is FunctionTranspiler
, which seems to work on Python AST level. How does it get the AST? This looks ugly. There is inspect_utils.getimmediatesource
which uses inspect.findsource
and inspect.getblock
. Which simply tries to get the source code filename and then loads that file. Then it calls gast.parse
, where gast
is this external Python package, which seems to wrap some incompatibilities in the official Python ast
package between Python 2 and Python 3. But this is basically ast.parse
, which uses the Python compile
builtin, with flags = PyCF_ONLY_AST
. Then the AST transformation logic happens in AutoGraphTranspiler
. Maybe most interesting is the ControlFlowTransformer
which handles if
and while
.
And I just scratched the surface of tf.function
autograph. This goes much deeper.
The question is if we maybe can get away with much simpler Python AST transpile logic and code. We can maybe reuse FunctionTranspiler
. Or maybe some other Python library for transpiling.
I played a bit around with the TF transpiler code, which is generic (although it would have been better if this would be independent of TF, because this probably only exists in TF2, and also the API might not be stable). A simple example can be seen here. While this is actually not too complicated for this simple logic, I'm not sure if this is still not way too complex.
So, given these options, I tend to prefer StateHolder
with __setattr__
and __getattr__
.
How would the StateHolder
with __setattr__
and __getattr__
look like? Here some possible variations:
The Loop
object already could create that, as loop.state
. It's only inside the with
block then but this is maybe ok.
Should we allow usages without defining the initial value or the shape? Maybe. In that case, the code for 2 LSTM layers, sharing params, not sharing hidden state can look like:
```python
lstm = Lstm(...)
with Loop() as loop:
  layer1, loop.state.layer1 = lstm(loop.unstack(x), loop.state.layer1)
  layer2, loop.state.layer2 = lstm(layer1, loop.state.layer2)
  output = loop.stack(layer2)
```
Or for 2 LSTM layers, sharing params, sharing hidden state:
```python
lstm = Lstm(...)
with Loop() as loop:
  layer1, loop.state.lstm = lstm(loop.unstack(x), loop.state.lstm)
  layer2, loop.state.lstm = lstm(layer1, loop.state.lstm)
  output = loop.stack(layer2)
```
Or consider this Python code:
```python
i = 0
for x_ in x:
  i = i + 1
```
Equivalent code here:
```python
with Loop() as loop:
  loop.unstack(x)
  loop.state.i = loop.state.i + 1
```
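The intended `loop.state` semantics can be emulated in plain Python to make them concrete (a hypothetical toy, not the real implementation): reading an attribute returns the value from the previous iteration, writing it registers the value for the next one.

```python
class StateHolder:
    def __init__(self, **initial):
        object.__setattr__(self, "_prev", dict(initial))
        object.__setattr__(self, "_cur", dict(initial))

    def __getattr__(self, name):
        # read: value from the previous loop iteration
        return self._prev[name]

    def __setattr__(self, name, value):
        # write: value for the next loop iteration
        self._cur[name] = value

    def _advance(self):
        # called by the loop machinery after each iteration
        object.__setattr__(self, "_prev", dict(self._cur))


state = StateHolder(i=0)
for x_ in [10, 20, 30]:  # stands in for iterating via loop.unstack(x)
    state.i = state.i + 1
    state._advance()
```

After three iterations `state.i` is 3, matching the plain-Python counter above.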
How to explicitly specify the initial value, and maybe other things like shape? Maybe this can just be extended, like so:
```python
with Loop() as loop:
  loop.unstack(x)
  loop.state.i = State(shape=(), initial=0)
  loop.state.i = loop.state.i + 1
```
This would be maybe a bit counter intuitive as loop.state.i
assignments and reads would normally expect or return a LayerRef
, and an assignment by State
is handled special.
But other variants might also look a bit inconsistent, like loop.define_state("i", initial=0)
or so. I'm not sure.
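For illustration, the special handling of a `State(...)` assignment could look like this in `__setattr__` (again a hypothetical sketch of the variant discussed, not decided API): a `State` value declares the entry and sets its initial value, any other value is a normal per-iteration assignment.

```python
class State:
    """Declaration object carrying shape and initial value."""
    def __init__(self, shape=(), initial=0):
        self.shape = shape
        self.initial = initial


class StateHolder:
    def __init__(self):
        object.__setattr__(self, "_values", {})

    def __setattr__(self, name, value):
        if isinstance(value, State):
            # declaration: register the initial value
            self._values[name] = value.initial
        else:
            # normal assignment inside the loop
            self._values[name] = value

    def __getattr__(self, name):
        return self._values[name]


state = StateHolder()
state.i = State(shape=(), initial=5)  # handled specially: declares i with value 5
state.i = state.i + 1                 # normal assignment: i becomes 6
```

This shows the counterintuitive part directly: the same assignment syntax does two different things depending on the value's type.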
The "current" example at the top looks already quite understandable and straightforward, but I have some comments / questions:

- Either the `unstack` function or the loop object itself should be able to take some kind of information what axes will be used for the loop. I think in the current example everything was based on the time axis.
- The `loop.last` should be able to get an `n` parameter so that you can get the n last states (just what the window layer inside a recurrent net does now, to implement e.g. causal convolution decoders).
- It is not clear yet to me how masks can/should be used, so if I want to update states only at a certain condition, like with a future `Cond` object, but with a mask tensor matching the "unstacked" axis.

> The "current" example at the top looks already quite understandable and straightforward, but I have some comments / questions:
> - Either the `unstack` function or the loop object itself should be able to take some kind of information what axes will be used for the loop. I think in the current example everything was based on the time axis.
Yes right. This is basically the discussion here: https://github.com/rwth-i6/returnn/issues/597, https://github.com/rwth-i6/returnn/issues/632 and related.
We still did not fully clarify whether we maybe should allow some defaults for cases where it is unique. Or basically we anyway need to do that for all existing layers to not break backward compatibility.
But anyway, this is somewhat orthogonal to the discussion here.
> - the `loop.last` should be able to get an `n` parameter so that you can get the n last states (just what the window layer inside a recurrent net does now, to implement e.g. causal convolution decoders)
No, I don't think so. It should follow the same principles as everything in RETURNN, it should be as simple as possible, and atomic. You can very easily get this functionality e.g. by putting a causal WindowLayer
and then get the loop.last
of that (that is anyway how loop.last
with n
would work internally).
> - ~~it is not clear yet to me how masks can/should be used, so if I want to update states only at a certain condition like with a future `Cond` object, but with a mask tensor matching the "unstacked" axis.~~ Sorry, this is Masked computation wrapper #23
Yes, this is #23, but actually, when we have all hidden state also now explicit, i.e. no distinction anymore, this also becomes pretty straightforward even without any such wrapper. The only reason such a wrapper can be useful is to allow potential further automatic optimizations (as `MaskedComputationLayer` does right now).
So, a first version is implemented now.
See the test in test_rec_ff
.
It uses this code:
```python
x = get_extern_data("data")
with Loop() as loop:
  x_ = loop.unstack(x, axis="T")
  loop.state.h = y_ = Linear(n_out=13)([x_, loop.state.h])
  y = loop.stack(y_)
return y
```
Which results in this net dict:
```python
{'loop': {'class': 'rec',
          'from': [],
          'unit': {'h': {'class': 'copy', 'from': 'linear'},
                   'linear': {'class': 'linear',
                              'from': ['rec_unstack', 'prev:h'],
                              'n_out': 13},
                   'output': {'class': 'copy', 'from': 'linear'},
                   'rec_unstack': {'axis': 'T',
                                   'class': 'rec_unstack',
                                   'from': 'base:data:data'}}},
 'output': {'class': 'copy', 'from': 'loop/output'}}
```
(Sorry for the pprint
formatting...)
(Some of the layer names will probably change in some future version.)
So I'm closing this now, as we have the initial design implemented.
Please open separate issues if sth is broken, missing, or whatever.
Just for reference, also Loop.end
has been implemented now.
What's still missing is the default interface for all wrapped RETURNN layers with hidden state, which should make the state more explicit, as discussed here. This is #31.
This issue is to collect some thoughts on the recurrent loops design, which wraps the `RecLayer` with an explicit subnetwork in RETURNN.

The main goal is to have this very straight-forward and simple for the user. We can abstract away from the underlying `RecLayer` if that makes things easier. We can also extend RETURNN itself if needed.

Related is also #6 (rec prev mechanism), and this issue here might fix/resolve #6, although not necessarily.

This also needs some mechanism for unrolling/unstacking, i.e. when we iterate over input `x` with some time-axis, i.e. to get `x[t]`. This is https://github.com/rwth-i6/returnn/pull/552.

To define a loop like this pseudo Python code:
Current design:

- There is `Loop()` which can be used in a `with` context, which corresponds to the `for`-loop in the example, or in general to a `while`-loop. Like:
- There is `State()` which can define hidden state (for any module or any code).

The example above can be written as:

Or with a module as:

For the TF name scopes (and variable scopes), we should follow #25, i.e. make it exactly as the module hierarchy.

The RETURNN layer name of the created `RecLayer` via `Loop` does not matter too much. It could be arbitrary, or some clever (but simple) logic to use the first module name or so. The RETURNN layer hierarchy can be independent from the actual TF name scopes (via #25).

Special options for the `RecLayer` like `include_eos` can be options for `Loop`, like `Loop(include_eos=True)`. Or as a method, like `loop.set_include_eos(True)`.

`Loop` (potential) methods:

- `unstack`. We need https://github.com/rwth-i6/returnn/pull/552 for this. `unstack` also implicitly implies that the loop runs over the time-axis of `x`.
- `last`
- `stack`
- `idx`: to return some layer which wraps RETURNN `':i'`

`State` has methods `get` and `assign`. (... See discussion below for more ...)

Current reasonings:
- Why no special base class `Rec` which derives from `Module`? We want to easily allow to use any kind of module inside a loop. We think the current API makes this more straight-forward.
- Why is `h` not an argument of `forward`, and why `State` instead? This allows to call other sub modules, which might define their own hidden state. So the root recurrent module does not need to know about all the hidden states of sub modules.
- Why to have the hidden state explicit, and not use sth more close to `self.prev`? To make the behavior more straight-forward.
- The current design allows for nested loops and sub modules with hidden state. Only the `Loop()` call actually introduces a new loop.
- There should not be any special handling needed for the `Choice` layer. Note that the search flag and train flag logic is a separate thing (#18).
- There should not be any special handling needed whether the input to a rec module call would be inside the current/same loop or not. `unstack` on some value which is already inside the loop would not make sense, though, and should result in an error. But this would all be covered by RETURNN logic already.
- RETURNN rec automatic optimization should not cause any problems. RETURNN already should guarantee that it is equivalent. From the user view point, it never ever should matter whether it is optimized. Otherwise this is rwth-i6/returnn#573. On this returnn-common level, it should not matter.
Example for LSTM for a single step:
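The code for that example is missing here; as a stand-in, a scalar pure-Python LSTM step with fully explicit state `(h, c)` could look like this (toy weights, only to illustrate the explicit-state signature discussed above; not RETURNN's actual cell):

```python
import math


def lstm_step(x, h, c, params):
    """One LSTM step on scalars. params: gate name -> (w_x, w_h, b).
    The state (h, c) is passed in and returned explicitly."""
    def sigmoid(v):
        return 1.0 / (1.0 + math.exp(-v))

    def gate(name, act):
        w_x, w_h, b = params[name]
        return act(w_x * x + w_h * h + b)

    i = gate("i", sigmoid)    # input gate
    f = gate("f", sigmoid)    # forget gate
    o = gate("o", sigmoid)    # output gate
    g = gate("g", math.tanh)  # cell candidate
    new_c = f * c + i * g
    new_h = o * math.tanh(new_c)
    return new_h, new_c


# toy parameters; with all-zero weights every gate is sigmoid(0) = 0.5
params = {name: (0.0, 0.0, 0.0) for name in ("i", "f", "o", "g")}
h, c = lstm_step(1.0, 0.0, 0.0, params)
```

The point is only the signature: nothing about the state is hidden, so a caller can substitute, share, or inspect `(h, c)` freely between steps.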