rwth-i6 / returnn_common

Common building blocks for RETURNN configs, such as models, training concepts, etc.

Rec design for recurrent definitions / loops #16

Closed albertz closed 3 years ago

albertz commented 3 years ago

This issue is to collect some thoughts on the recurrent loops design, which wraps the RecLayer with an explicit subnetwork in RETURNN.

The main goal is to make this very straightforward and simple for the user. We can abstract away from the underlying RecLayer if that makes things easier. We can also extend RETURNN itself if needed.

Also related is #6 (rec prev mechanism); this issue here might fix/resolve #6, although not necessarily.

This also needs some mechanism for unrolling/unstacking, i.e. when we iterate over an input x with some time axis to get x[t]. This is https://github.com/rwth-i6/returnn/pull/552.


To define a loop like this pseudo Python code:

x  # given, shape {batch, time, dim}
h = Zeros({batch,dim})()  # initial state, shape {batch,dim}
out = []
for t in range(x.max_seq_len):
  x_lin = Linear(dim)(x[t])
  h_prev = h
  h = Linear(dim)(x_lin + h_prev)
  out.append(h)

h  # final state
out  # shape {time, batch, dim}
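
The pseudocode above can be checked with a plain-NumPy sketch (illustrative names, sizes, and weights only, not the returnn-common API; biases omitted):

```python
import numpy as np

# illustrative sizes; dim is both input and output dim to keep it simple
batch, time, dim = 2, 5, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(batch, time, dim))    # given, shape {batch, time, dim}
w_x = rng.normal(size=(dim, dim))          # stand-in for Linear(dim) on x[t]
w_h = rng.normal(size=(dim, dim))          # stand-in for Linear(dim) on x_lin + h_prev

h = np.zeros((batch, dim))                 # initial state, shape {batch, dim}
out = []
for t in range(time):                      # for t in range(x.max_seq_len)
    x_lin = x[:, t] @ w_x                  # x_lin = Linear(dim)(x[t])
    h = (x_lin + h) @ w_h                  # h = Linear(dim)(x_lin + h_prev)
    out.append(h)
out = np.stack(out)                        # shape {time, batch, dim}
```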

Current design:

There is Loop() which can be used in a with context, which corresponds to the for-loop in the example, or in general to a while-loop. Like:

with Loop() as loop:
  ...

There is State() which can define hidden state (for any module or any code).

The example above can be written as:

h = State({batch, dim}, initial=0)
with Loop() as loop:  # this introduces a new loop
  x_t = loop.unstack(x)  # shape {batch, dim}

  x_lin = Linear(dim)(x_t)
  h_prev = h.get()
  h_ = Linear(dim)(x_lin + h_prev)  # shape {batch, dim}
  h.assign(h_)

  out = loop.stack(h_)  # shape {time,batch,dim}
  h_last = loop.last(h_)

# h.get() would now return the last state
# h_last is an alternative

Or with a module as:

class MyRec(Module):
  def __init__(self):
    super().__init__()
    self.x_linear = Linear(dim)
    self.h_linear = Linear(dim)
    self.h = State({batch, dim}, initial=0)

  def forward(self, x):
    # x shape is {batch, dim}
    x_lin = self.x_linear(x)
    h_prev = self.h.get()
    h = self.h_linear(x_lin + h_prev)  # shape {batch, dim}
    self.h.assign(h)
    return h

rec = MyRec()
with Loop() as loop:  # this introduces a new loop
  x_t = loop.unstack(x)  # shape {batch, dim}
  h_ = rec(x_t)  # shape {batch,dim}. this represents the inner value
  h = loop.last(h_)  # shape {batch,dim}
  out = loop.stack(h_)  # shape {time,batch,dim}

For the TF name scopes (and variable scopes), we should follow #25, i.e. make it exactly as the module hierarchy.

The RETURNN layer name of the created RecLayer via Loop does not matter too much. It could be arbitrary, or some clever (but simple) logic to use the first module name or so. The RETURNN layer hierarchy can be independent from the actual TF name scopes (via #25).

Special options for the RecLayer like include_eos can be options for Loop, like Loop(include_eos=True). Or as a method, like loop.set_include_eos(True).

Loop (potential) methods:

State has methods get and assign. (... See discussion below for more ...)

Current reasonings:

Why no special base class Rec which derives from Module? We want to easily allow using any kind of module inside a loop. We think the current API makes this more straightforward.

Why is h not an argument of forward, and why State instead? This allows to call other sub modules, which might define their own hidden state. So the root recurrent module does not need to know about all the hidden states of sub modules.

Why have the hidden state explicit, and not use something closer to self.prev? To make the behavior more straightforward.

The current design allows for nested loops and sub modules with hidden state. Only the Loop() call actually introduces a new loop.

class MySubRec(Module):
  def __init__(self):
    super().__init__()
    self.h = State({batch,dim})

  def forward(self, a):
    # assume a shape {batch,dim}
    h = self.h.get() + a
    self.h.assign(h)
    return h

class MyRec(Module):
  def __init__(self):
    super().__init__()
    self.sub = MySubRec()
    self.h = State({batch,dim})

  def forward(self, x):
    a = self.h.get() + x

    # example with sub as nested loop
    with Loop() as loop:
      y = self.sub(a)
      y = loop.last(y)

    # or: example with sub in same loop
    y = self.sub(a)

    self.h.assign(y)
    return y

There should not be any special handling needed for the Choice layer. Note that the search flag and train flag logic is a separate thing (#18).

There should not be any special handling needed whether the input to a rec module call would be inside the current/same loop or not. unstack on some value which is already inside the loop would not make sense, though, and should result in an error. But this would all be covered by RETURNN logic already.

RETURNN rec automatic optimization should not cause any problems. RETURNN should already guarantee that it is equivalent. From the user's viewpoint, it should never matter whether it is optimized. Otherwise that is rwth-i6/returnn#573. On this returnn-common level, it should not matter.


Example for LSTM for a single step:

class Lstm(Module):
  def __init__(self):
    super().__init__()
    self.h = State({batch,dim})
    self.c = State({batch,dim})
    self.ff_linear = Linear(dim * 4)
    self.rec_linear = Linear(dim * 4)

  def forward(self, x):
    # x shape is {batch,dim} (single frame)
    x_ = self.ff_linear(x)
    h_ = self.rec_linear(self.h.get())
    x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
    c = self.c.get() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
    self.c.assign(c)
    h = tanh(c) * sigmoid(g_out)
    self.h.assign(h)
    return h
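
As a sanity check of the gate arithmetic in this Lstm draft, here is a plain-NumPy sketch of one step (illustrative: no biases, random weights, zero initial state; not the returnn-common API):

```python
import numpy as np

batch, dim = 2, 3
rng = np.random.default_rng(0)
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))

w_ff = rng.normal(size=(dim, dim * 4))   # stand-in for self.ff_linear
w_rec = rng.normal(size=(dim, dim * 4))  # stand-in for self.rec_linear

x = rng.normal(size=(batch, dim))        # single frame, shape {batch, dim}
h_prev = np.zeros((batch, dim))          # State h, initial 0
c_prev = np.zeros((batch, dim))          # State c, initial 0

gates = x @ w_ff + h_prev @ w_rec                          # x_ + h_
x_in, g_in, g_forget, g_out = np.split(gates, 4, axis=-1)  # split(x_ + h_, 4)
c = c_prev * sigmoid(g_forget) + np.tanh(x_in) * sigmoid(g_in)
h = np.tanh(c) * sigmoid(g_out)
```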
albertz commented 3 years ago

The example should be extended:

How are nested loops handled? This should be straight-forward as well. And nothing must be ambiguous about it, e.g. the (implicit or explicit) definition and handling of hidden state.

albertz commented 3 years ago

What are reasonable sub modules with hidden state? How would you define them? Just normal Modules? Or other Recs where you call step?

albertz commented 3 years ago

Custom ending condition, how to define?

albertz commented 3 years ago

How to get the last hidden state?

albertz commented 3 years ago

Also related: https://github.com/rwth-i6/returnn/issues/391 https://github.com/rwth-i6/returnn/pull/545

I.e. the question about how to define masked self-attention in a generic/flexible way in this Rec concept. E.g. HistoryMaskLayer or so. This could all be wrapped directly, but is this intuitive then? What does it look like? Can we make it more intuitive?

albertz commented 3 years ago

Also: Is there anything special needed for Choice (ChoiceLayer)? And beam search? Maybe not and it is all straight-forward already.

How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?

Atticus1806 commented 3 years ago

How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?

From what I can see, RETURNN handling should be enough, assuming that config changes needed for automation can be done elsewhere.

Custom ending condition, how to define?

Maybe give an optional parameter for Rec like Rec.custom_break which will be checked internally at the beginning of the loop?

self.h = self.state({batch, dim})

How do you concretely think of this init? This requires an explicit value for the batch dimension, right? I am not sure if this is what we should aim for; RETURNN should be able to handle this already? Or is this right now just for example's sake?


with rec.loop() as loop:  # this introduces a new loop
  h_ = rec(x)  # shape {batch,dim}. this represents the inner value
  h = loop.last(h_)  # shape {batch,dim}
  out = loop.stack(h_)  # shape {time,batch,dim}

Maybe this example (esp. for the docs) could be extended with another (non-rec) module, showing how to use it in a Module, so that someone who has not worked with Rec in this way yet directly has an example which can be modified. Maybe something like (assuming this is how it's intended):

class EncDec(Module):
  def __init__(self):
    super().__init__()
    self.enc = Linear(dim)
    self.dec = MyRec()

  def forward(self, x):
    x = self.enc(x)
    with self.dec.loop() as loop:
      h_ = self.dec(x)  # shape {batch,dim}. this represents the inner value
      h = loop.last(h_)  # shape {batch,dim}
      out = loop.stack(h_)  # shape {time,batch,dim}
    return out

Also: Is there anything special needed for Choice (ChoiceLayer)? And beam search? Maybe not and it is all straight-forward already.

We are mainly aiming for conversion to RETURNN here, so maybe, if we choose a smart way to include them in the loop, plus some internal handling, this can be converted (since in RETURNN they are part of the rec unit anyway).

albertz commented 3 years ago

How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?

From what I can see, RETURNN handling should be enough, assuming that config changes needed for automation can be done elsewhere.

Note that this is not just about what is enough to be able to define all networks. Or not sure how you mean it.

It's about being straightforward and clear. I.e. at no time should it be unclear to the user when search is used.

We do not have to follow exactly the behavior of RETURNN. There are also multiple ways in RETURNN. We can restrict it to one clean way. We can also change it. Or introduce a simpler variant here.

I'm tending to make it explicit. But not sure.

PyTorch also has a similar concept for the train flag (as we also have in RETURNN). I.e. some PyTorch modules behave differently depending on whether they are in train or eval mode (e.g. Dropout). We have exactly the same in RETURNN. And search is a flag like train.

The difference is in how these flags are set.

Maybe again, here in returnn-common, we can follow the PyTorch style a bit for this, and also copy it for the search flag? Not sure...

Edit: I think this is actually a separate thing, which we can discuss/handle independently. So I opened #18 about this. And this can be marked as off-topic here.

albertz commented 3 years ago

Custom ending condition, how to define?

Maybe give an optional parameter for Rec like Rec.custom_break which will be checked internally at the beginning of the loop?

In RETURNN, the layer is called end. So maybe, additionally to step, the user can define a method end (similar to how tf.while_loop gets body and cond).

This end method can be optional. E.g. when unstack is used at some point, then the sequence length is implied from this. The same is true when cross-entropy loss is used, as unstack would be used on the targets.
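
The body/cond split can be sketched in plain Python (illustrative stand-ins, not the actual API): the end check plays the role of tf.while_loop's cond, and in this sketch the frame that triggers it is still emitted (cf. the include_eos discussion above).

```python
# plain-Python analogue of an explicit end condition (illustrative):
# the step is the loop body; the end check corresponds to tf.while_loop's cond.
EOS = 0
frames = iter([3, 2, 1, EOS, 5])
out = []
while True:
    y = next(frames)   # one step of the rec body
    out.append(y)      # loop.stack
    if y == EOS:       # the optional end condition
        break
# out == [3, 2, 1, 0]
```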

albertz commented 3 years ago

self.h = self.state({batch, dim})

How do you concretely think of this init? This requires an explicit value for the batch dimension, right? I am not sure if this is what we should aim for; RETURNN should be able to handle this already? Or is this right now just for example's sake?

batch here would be the dimension tag (#17). I.e. this is kind of a placeholder. This does not require an explicit value. And yes, RETURNN has all that logic already.

albertz commented 3 years ago
with rec.loop() as loop:  # this introduces a new loop
  h_ = rec(x)  # shape {batch,dim}. this represents the inner value
  h = loop.last(h_)  # shape {batch,dim}
  out = loop.stack(h_)  # shape {time,batch,dim}

Maybe this example (esp. for the docs) could be extended with another (non-rec) module, showing how to use it in a Module, so that someone who has not worked with Rec in this way yet directly has an example which can be modified. Maybe something like (assuming this is how it's intended):

class EncDec(Module):
  def __init__(self):
    super().__init__()
    self.enc = Linear(dim)
    self.dec = MyRec()

  def forward(self, x):
    x = self.enc(x)
    with self.dec.loop() as loop:
      h_ = self.dec(x)  # shape {batch,dim}. this represents the inner value
      h = loop.last(h_)  # shape {batch,dim}
      out = loop.stack(h_)  # shape {time,batch,dim}
    return out

Yes sure. Although I thought this part was clear.

albertz commented 3 years ago

Also: Is there anything special needed for Choice (ChoiceLayer)? And beam search? Maybe not and it is all straight-forward already.

We are mainly aiming for conversion to RETURNN here, so maybe, if we choose a smart way to include them in the loop, plus some internal handling, this can be converted (since in RETURNN they are part of the rec unit anyway).

No; simplicity and clear, straightforward behavior have priority over everything. In the end, everything can be converted to RETURNN, because RETURNN should be generic enough to allow for everything (and if that is not the case, we should simply extend RETURNN itself). Especially if the RETURNN behavior is unintuitive or not straightforward, we should not just copy it here.

I'm not sure what you mean by "smart way", and "internal handling" for what?

Right now, you would just use Choice as you would use other layers like Linear. I was wondering whether we need any special behavior for Choice here in returnn-common, or not. Currently I don't really see a reason for special behavior. But maybe I'm overlooking something.

This is of course related to the search aspect. Although, even when we can now set the search flag explicitly in returnn-common, the logic of how the ChoiceLayer behaves with or without the search flag is already all part of RETURNN, so no special handling is needed for that here.

albertz commented 3 years ago

Note: I changed the original suggestion of automatically stacking the output, and automatically implying the construction of the loop, because it was not clear for sub modules with hidden state when they should create a new nested loop or reuse the same loop. So now this is explicit, via loop().

However, I'm not sure if this is now the most intuitive, clean, straightforward way. Although other ways I can think of right now are not really better either.

When we keep using Module everywhere and do not use Rec as a base class, and use Rec.State() and Rec.Loop() explicitly, this would already work in this draft. (We can still use Rec as a namespace to cover all rec-related logic. Or maybe another Python module/package for this. Or just use RecLoop() instead of Rec.Loop(). But this is not relevant for now.) It's just a technical question then how to handle this code:

# Definition for a single step:
class MyRec(Module):
  ...

# Now the loop:
rec = MyRec()
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = rec(x_)
  y = loop.stack(y_)

Or, alternatively to Rec.Loop(), maybe Rec stays a Module where you define the inner step, which would just exactly contain that code:

# Definition for a single step:
class MyRec(Module):
  ...

# The loop:
class MyRecLoop(Rec):
  def __init__(self):
    self.rec = MyRec()

  def step(self, x):
    x_ = self.unstack(x)
    y_ = self.rec(x_)
    return self.stack(y_)

loop = MyRecLoop()
y = loop(x)
Atticus1806 commented 3 years ago

Currently I don't really see a reason for special behavior. But maybe I'm overlooking sth.

I agree with that; right now it does not seem like we need it, but again, I can report back once I finish my config whether there is some weird behavior we didn't think of.

I'm not sure what you mean by "smart way", and "internal handling" for what?

That was weird phrasing. What I meant, pretty much, is that I would have all of that Choice and beam-search handling in `returnn_common` and not leave it to the user to find the "correct" way to use the layers; best case, it should just be used, as you said, similar to a Linear layer, imo.

albertz commented 3 years ago

Also, in some earlier draft, there was a separate init (or initial_state, or step0) method to define the initial state, next to step, with the same arguments as step (i.e. input x in the example). I changed that now and merged the logic into step, because:

However, having both logic merged in step also has downsides. E.g. maybe this is a bit unintuitive.

Also, I wonder whether it makes sense to decouple the initial-state logic more. Maybe, for some given rec module, the user later wants to overwrite just the initial state with some custom logic. This is an important aspect. It is also relevant for rec sub modules. This can be recursive.

Edit: Ok, we can simply allow that the init of a state (e.g. rec.h.init(...)) can be called multiple times. Also, the loop object can collect all state vars and allow for easy overwriting of all initial state vars recursively. Then it is basically solved.

albertz commented 3 years ago

Do we require from the user that a rec module is written in a way that it can operate automatically both on a frame-by-frame level (this is the natural way) and also on a whole-sequence level?

Or do we actually require anything special for this, or this is anyway already all automatic?

Most (all?) builtin modules (e.g. Linear etc.) already handle this fine. This also holds for modules with internal state like Lstm. Although we should maybe distinguish between modules where it does not matter at all (e.g. Linear) and modules which really have different implementations and logic (Lstm). Maybe this is actually bad, e.g. in the case of Lstm, that this has different automatically implied behavior? The current behavior is implied by whether there is a time dim (assumes outside the loop, operates on the time dim) or not (assumes inside the loop). This might be problematic for nested loops. But maybe not. I think it is actually fine when a module/layer always derives its behavior from the shape, e.g. whether there is a time dim, independent of the context (whether in a loop or not). And this is maybe already the case. If not, it would cause errors/exceptions, and these can be fixed individually, on the RETURNN side.

What happens though when a rec module is called without an outer loop? It would simply just calculate one step, on the input? And use the defined initial state? Is there any problem with that? What happens on unstack?

Consider this example for a LSTM on a single step:

class Lstm(Rec):
  def __init__(self):
    super().__init__()
    self.h = self.state({batch,dim})
    self.c = self.state({batch,dim})
    self.ff_linear = Linear(dim * 4)
    self.rec_linear = Linear(dim * 4)

  def step(self, x):
    # x shape is {batch,dim} (single frame)
    x_ = self.ff_linear(x)
    h_ = self.rec_linear(self.h.prev())
    x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
    c = self.c.prev() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
    self.c.assign(c)
    h = tanh(c) * sigmoid(g_out)
    self.h.assign(h)
    return h

Now when called with an input which has a time-dim, should that automatically introduce the loop around it? No, there should not be any fancy automatic implied logic. The automatic optimization would happen on RETURNN side. I think this should be up to the user, to explicitly add code for that.

To answer the original question: No, we do not require this.

But how is the behavior of state.h.prev(), state.h.init(), state.h.assign() defined exactly in each case? I.e. without any loop, with a loop, or multiple nested loops. Or with multiple calls to the module?

To extend the original question: Do we allow this? Can the user write this Lstm module such that it works both for a single step and a whole sequence?

albertz commented 3 years ago

The input to step for rec modules is currently defined to be outside the loop, or to be the same in each frame. Or is it actually? In the example, x is outside the loop. But is that relevant for the loop (except if the user calls unstack on it)?

For rec sub modules, the input could also be some value from within the current loop (as in the example for rec sub modules).

Do we need to take any special care here? Or is all just fine?

I think this is all fine. I don't see how this has any relevance.

albertz commented 3 years ago

Why do we need self.state, and not a global Rec.State? The state would always be related to the innermost loop after a rec module call (where the state assign happens).

albertz commented 3 years ago

Do we always need the Rec base class? We could use a normal Module, and just forward instead of step.

Also, a Module might define states (via Rec.State). Whether it has "state" cannot depend on the input, because the state is defined in __init__. Although whether and how the state is used can be dynamic in the forward function.

E.g. a Lstm step as a module for a single step:

class Lstm(Module):
  def __init__(self):
    super().__init__()
    self.h = Rec.State({batch,dim})
    self.c = Rec.State({batch,dim})
    self.ff_linear = Linear(dim * 4)
    self.rec_linear = Linear(dim * 4)

  def forward(self, x):
    # x shape is {batch,dim} (single frame)
    x_ = self.ff_linear(x)
    h_ = self.rec_linear(self.h.prev())
    x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
    c = self.c.prev() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
    self.c.assign(c)
    h = tanh(c) * sigmoid(g_out)
    self.h.assign(h)
    return h

lstm = Lstm()
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = lstm(x_)
  y = loop.stack(y_)

How does the Rec base class change anything here?

albertz commented 3 years ago

We could also allow a loop around a normal Module. Sth like:

mod = MyModule()
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = mod(x_)
  y = loop.stack(y_)

Although, due to RETURNN automatic optimization, this would be equivalent to:

mod = MyModule()
y = mod(x)

So, if this is possible, this touches on the question whether we need the Rec base class.

The question is whether this is always possible or would result in problems.

albertz commented 3 years ago

Note: I changed the original suggestion of automatically stacking the output, and automatically implying the construction of the loop, because it was not clear for sub modules with hidden state when they should create a new nested loop or reuse the same loop. So now this is explicit, via loop().

When we keep using Module everywhere and do not use Rec as a base class, and use Rec.State() and Rec.Loop() explicitly, this would already work in this draft. (We can still use Rec as a namespace to cover all rec-related logic. Or maybe another Python module/package for this. Or just use RecLoop() instead of Rec.Loop(). But this is not relevant for now.) It's just a technical question then how to handle this code:

# Definition for a single step:
class MyRec(Module):
  ...

# Now the loop:
rec = MyRec()
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = rec(x_)
  y = loop.stack(y_)

Or, alternatively to Rec.Loop(), maybe Rec stays a Module where you define the inner step, ...

Thinking further about this, I now prefer the solution to just have it all as a normal Module, with no special Rec base module class with step. And then there is the special Rec.State. Or instead of Rec.State, maybe just State, which every normal Module can use in __init__.

And further, I think it is most straightforward to not have any init logic inside step / forward (there is no special step anymore when we just always use Module).

The reasoning is that I think the easiest way for the user to think about the logic is when with RecLoop() as loop: can be seen as analogous/equivalent to a for i in range(...) or while ... loop. There should be no special cases such as "this init calculation would only happen before the first step". The mental model for the user should really allow exactly this direct translation. From this simple principle, we would infer all the other behavior.

This implies that there should not be any init logic in forward, and it would really just execute a single step when inside the loop.

Further, this State object would have just an assign method. There is no need for a special init. The assign can infer from the scope where it belongs (e.g. some init outside the loop). And also, it would just have a get or read method. There is no need for prev because we can infer internally whether we need prev. Thus, if you want some special init, you would just write it like this:

# Definition for a single step:
class MyRec(Module):
  ...

# Now the loop:
rec = MyRec()
rec.h.assign(...)  # init
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = rec(x_)
  y = loop.stack(y_)
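
Following the principle above, this loop code has a direct plain-Python translation (illustrative stand-ins for the module and inputs):

```python
def step(x_t, h_prev):
    # stand-in for the rec module call; any stateful computation
    return x_t + h_prev

x = [1, 2, 3]
h = 0                     # rec.h.assign(0): init, outside the loop
out = []
for x_t in x:             # with Rec.Loop() as loop: / x_ = loop.unstack(x)
    h = step(x_t, h)      # y_ = rec(x_): get reads the previous value
    out.append(h)         # y = loop.stack(y_)
# out == [1, 3, 6]; h == 6 is the last state
```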

Technically, some extra logic is maybe needed to define the RETURNN layer names based on this code. The Rec.Loop() here would create the RecLayer somehow, but this does not really have a name. The rec here might have a name (because it usually would be part of the main module as a sub module). So we could use that name (and scope) if there is just a single module call inside the loop. If there are multiple module calls inside the loop, it is not so straightforward. Maybe we take the first. Or we somehow introduce anonymous layer names in RETURNN (so the RecLayer would not really have a name and not introduce a new level in the name-scope hierarchy). But this aspect is only about the wrapping and layer names; the user would maybe not care so much about this, but more about the logic of the model / calculation.

On the question whether RETURNN automatic optimization might cause any problems: RETURNN should already guarantee that it is equivalent. From the user's viewpoint, it should never matter whether it is optimized. Otherwise that is https://github.com/rwth-i6/returnn/issues/573. On this returnn-common level, it should not matter. (Maybe we want to introduce potential optimization on this higher level as well. But that would be a separate topic, and we can deal with it later.)

Atticus1806 commented 3 years ago

If there are multiple module calls inside the loop, it is not so straight-forward.

You mean something like:

rec = MyRec()
rec.h.assign(...)  # init
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = rec(x_)
  y = loop.stack(y_)
with Rec.Loop() as loop:
  x_ = loop.unstack(y)
  y_ = rec(x_)
  y = loop.stack(y_)

? I am not sure if we should allow this behavior, because I feel like it can cause quite some confusion. I think it would be okay to have the user use a rec only once. Or maybe I am missing the point here.

Maybe we want to introduce potential optimization also on this higher level. But this would be another separate topic, and we can deal with that later.

If we do, we should be careful not to have these two optimizations interfere with each other. The only case I can think of is an optimization that RETURNN is not able to do because of the config structure, but from what I know, there should not be such cases? Otherwise we might optimize something RETURNN would currently optimize, and after changing RETURNN behavior (for whatever reason) this might then not work anymore (since RETURNN would like to optimize differently).

albertz commented 3 years ago

If there are multiple module calls inside the loop, it is not so straight-forward.

You mean something like:

rec = MyRec()
rec.h.assign(...)  # init
with Rec.Loop() as loop:
  x_ = loop.unstack(x)
  y_ = rec(x_)
  y = loop.stack(y_)
with Rec.Loop() as loop:
  x_ = loop.unstack(y)
  y_ = rec(x_)
  y = loop.stack(y_)

No. I was referring to the case of multiple module calls inside a loop, e.g. like:

rec1 = MyRec1()
rec2 = MyRec2()
with Rec.Loop() as loop:
  y1_ = rec1(x)
  y2_ = rec2(y1_)
  y = loop.stack(y2_)

But I was anyway just referring to the question of how this is mapped to RETURNN layers. Which is not so important, just a technical detail.

The (model/computation) behavior in all the examples is very clear. This always follows from the principle I stated above (Rec.Loop() simply corresponds to a Python while-loop).

? I am not sure if we should allow this behavior, because I feel like this can cause quite some confusion.

I don't think there is any confusion? Just think of Rec.Loop() as a while-loop. It's very clear what happens then.

Maybe we want to introduce potential optimization also on this higher level. But this would be another separate topic, and we can deal with that later.

If we do we should be careful not to have these two optimizations interfere with eachother. The only case I can think of is optimization RETURNN is not able to do because of the config structure, but from what I know there should not be these cases? Otherwise we might optimize something RETURNN would currently optimize and after changing RETURNN behavior (for whatever reason) this might then not be working anymore (since RETURNN would like to optimize differently).

I have not really thought this fully through. I was more thinking about the case when the user might want to provide a more efficient implementation for a whole sequence (e.g. for the LSTM example). Or more relevant maybe for self-attention (see also the referenced issues above). But I think we don't have to consider this now.

albertz commented 3 years ago

Custom ending condition, how to define?

Maybe give an optional parameter for Rec like Rec.custom_break which will be checked internally at the beginning of the loop?

In RETURNN, the layer is called end. So maybe, additionally to step, the user can define a method end (similar to how tf.while_loop gets body and cond).

This end method can be optional. E.g. when unstack is used at some point, then the sequence length is implied from this. The same is true when cross-entropy loss is used, as unstack would be used on the targets.

To update on this: when we now stick to Module as the base class (no Rec), I would not use a special additional end method (the Module just has the forward method).

I think I would instead make it a method of the Loop object. Like this:

rec = MyRec()
with Loop() as loop:
  y_ = rec(x)
  y = loop.stack(y_)
  loop.end(y_ == EOS)  # example end condition

There is also the question about the include_eos option of the RecLayer. This might be an option for Loop (like Loop(include_eos=True)). Or a method, like loop.set_include_eos(True).

albertz commented 3 years ago

Also related: rwth-i6/returnn#391 rwth-i6/returnn#545

I.e. the question about how to define masked self-attention in a generic/flexible way in this Rec concept. E.g. HistoryMaskLayer or so. This could all be wrapped directly, but is this intuitive then? What does it look like? Can we make it more intuitive?

Extending on this, example draft:

with Loop() as loop:
  # x is [B,D] inside, [T,B,D] outside
  qkv = Linear(2*K+V)(x)  # [B,2*K+V] inside, [T,B,2*K+V] outside
  q, k, v = split(qkv, size_splits=[K,K,V])  # [B,K|V] inside, [T,B,K|V] outside
  k_accum = cum_concat(k)  # [B,T',K] inside, [B,T'',K] outside
  v_accum = cum_concat(v)  # [B,T',V] inside, [B,T'',V] outside
  energy = dot(q, k_accum, red1="static:-1", red2="static:-1", var1="T?", var2="T")  # [B,T] inside, [T,B,T'] outside
  att_weights = softmax_over_spatial(energy)  # [B,T] inside, [T,B,T'] outside
  att = dot(v_accum, att_weights, red1="T", red2="stag:history", var1="static:-1", var2=[])  # [B,V] inside, [T,B,V] outside

Note that the cum_concat behavior is conceptually similar to our loop.stack. Maybe this should be unified?
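
A plain-NumPy sketch of the inside-the-loop view of cum_concat (illustrative; the history accumulator and names are assumptions, not the API): at step t it returns all frames up to and including t, so the energy gets a growing history axis, much like loop.stack collects frames.

```python
import numpy as np

batch, dim_k = 2, 4
rng = np.random.default_rng(0)

history = []                          # per-loop accumulator, like loop.stack
def cum_concat(frame):                # frame: [B, K]
    history.append(frame)
    return np.stack(history, axis=1)  # [B, T', K]: history incl. current frame

energy_lens = []
for t in range(3):
    q = rng.normal(size=(batch, dim_k))              # query for this frame
    k = rng.normal(size=(batch, dim_k))              # key for this frame
    k_accum = cum_concat(k)                          # [B, t+1, K]
    energy = np.einsum('bk,btk->bt', q, k_accum)     # [B, t+1]
    energy_lens.append(energy.shape[1])
# energy_lens == [1, 2, 3]: the history axis T' grows with the step
```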

Note that we have the principle that the user should not need to think about the automatic optimization (https://github.com/rwth-i6/returnn/issues/573).

The axes descriptions and axes themselves still need to be worked out (see referenced RETURNN issue above). softmax_over_spatial (SoftmaxOverSpatialLayer) needs to be clever about the history time axis.

albertz commented 3 years ago

We should also work out what masked computation looks like, on the example of a transducer with SlowRNN and FastRNN. Again, in this example, what we want:

Example code draft:

x  # shape {batch,enc_time,dim}
slow_rnn = SlowRNN()
fast_rnn = FastRNN()
blank_pred = Linear(1)
non_blank_pred = Linear(...)
t = State({Batch}, dtype=int32, initial=0)
align_label = State({Batch}, dtype=int32, initial=0)
with Loop() as loop:  # over alignment labels
  x_t = x[t]  # shape {batch,dim}
  with MaskedComputation(mask=(align_label.get() != BLANK)):
    slow = slow_rnn(align_label.get(), x_t)
  fast = fast_rnn(align_label.get(), x_t, slow)
  blank_pred_energy = blank_pred(fast)
  log_prob_blank = log_sigmoid(blank_pred_energy)
  log_prob_not_blank = log_sigmoid(-blank_pred_energy)
  log_prob_non_blank_labels = log_softmax(non_blank_pred(fast))
  log_prob_combined = concat(log_prob_non_blank_labels + log_prob_not_blank, log_prob_blank)
  align_label.assign(choice(log_prob_combined, input_type="log_prob"))
  loop.end(t >= x.seq_len)

This is not so much about MaskedComputation here (see #23 about this) but especially about the question whether there is anything special w.r.t. the Loop concept, or whether this would just work as-is. Maybe that is already the case.

albertz commented 3 years ago

One aspect we should define more clearly is the behavior of State.assign. This currently depends on the context.


Resolved: No.

So far this is clear. However, how about masked computation (#23) or conditional code (#24)? This probably should be disallowed.

Edit Just disallowing this is not really an option. Any custom module could define its own hidden state, and this would not work then. E.g. in the example above, slow_rnn definitely would have hidden state. Imagine that SlowRNN is like the Lstm defined before.

Edit I think the intended behavior is actually clear in all cases. Just translate the conditional code with if/else. And the masked computation is kind of like if mask: ..., where the else: branch would just repeat the previous value.
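A minimal runnable sketch of that should-be behavior (pure Python, names are illustrative, not the real API): a masked step keeps the previous state where the mask is False.

```python
# Hypothetical pure-Python sketch of the masked-update semantics: where the
# mask is False, the state just repeats the previous value.
def masked_update(prev_state, new_state, mask):
    return new_state if mask else prev_state

state = 0
history = []
for x, m in [(1, True), (2, True), (5, False), (3, True)]:
    state = masked_update(state, state + x, m)
    history.append(state)
# at the masked step (x=5) the state stays unchanged
```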


Edit On the behavior of State and State.assign: This also goes along with the behavior of State.get. State.assign itself should (must) be allowed everywhere. The sequence of get and assign and their corresponding context matters.

Edit So a bit generalizing and specifying:

State.assign overwrites the previous assign when done in the same context and there was no read (get) in between. Although maybe we don't really need to care about overwrites?

Contexts are hierarchical, and ordered. So far we have Loop (this proposal here), Cond (#24) and MaskedComputation (#23). Maybe more later? In TF, as far as I know, there is only tf.while_loop and tf.cond.

There is an implicit first assign via the initial value. And also this assign is in the root context. So assign is always before read. I.e. for some read, there is always a previous assign, in the same or upper context.

First read in Loop ctx will correspond to the last assign in the ctx.
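These rules can be sketched as a small runnable simulation (hypothetical, not the real returnn-common API): a read before any assign within a loop iteration returns the value carried over from the previous iteration, and an assign sets the value seen by the next one.

```python
# Hypothetical simulation of the read/assign bookkeeping described above.
class StateVar:
    def __init__(self, initial):
        self.prev = initial  # value carried over from the previous iteration
        self.cur = None      # value assigned within the current iteration

    def get(self):
        # a read before any assign in this iteration sees the previous value,
        # i.e. what would become "prev:..." in RETURNN
        return self.prev if self.cur is None else self.cur

    def assign(self, value):
        self.cur = value

    def end_iteration(self):
        # the last assigned value becomes the next iteration's "prev"
        self.prev, self.cur = self.get(), None

h = StateVar(initial=0)
for x in [1, 2, 3]:
    h.assign(h.get() + x)  # get() before assign reads the previous iteration
    h.end_iteration()
# h.prev is now 0 + 1 + 2 + 3
```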


Resolved: This is wrong

Edit Actually, is there anything special about State? Don't we need the same logic even in general? Consider the following code:

y = const(0)
with Loop():
  y = y + 1
return y

What would you expect? Or consider this:

with Cond(...) as cond:
  y = const(1)
with cond.false_branch():
  y = const(2)
return y

y would be a LayerRef. Accessing it / reading it would then consider its definition and the context of its definition, and translate this to the current context.

This probably cannot work like this because we cannot really catch the y assignment (y = ...), so we also do not know what is a reassignment of the same local Python variable. (It works for object attributes, e.g. like net.y = ..., but not sure if this would be so clear...)
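This limitation can be demonstrated with plain Python (illustrative sketch): rebinding a local name creates a new object, and anything that captured the old object never sees the update.

```python
# Illustrative sketch of why the `y = ...` rebinding cannot be caught:
# rebinding a local just points the name at a new object; whoever captured
# the old object (as a hypothetical Loop() would have to) never sees it.
class Ref:
    def __init__(self, value):
        self.value = value

y = Ref(0)
captured = y  # what a framework could have captured at loop entry
for x in [1, 2, 3]:
    y = Ref(y.value + x)  # rebinds the local name; 'captured' is unaffected
# captured.value is still 0 while y.value is 6
```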

So to answer the question, yes, State is special. There are different ways to design this, though. Maybe State is also a bad name, and it should be StateVar or just Var or so.

State can derive from LayerRef, though. The read (or get) basically corresponds to get_name. Maybe this API should be refactored (renamed) a bit. So we can catch that. And we would add assign.


Edit Back to behavior definition: The first read (or get_name) inside a Loop ctx would result in "prev:..." iff there is an assign later in the Loop ctx. This implies that the layer dict cannot be created right away, because we have to wait until the Loop ctx is finished to know whether there is an assign later. Or do we just always make it output prev:...? No... We should not make it a recurrent var when not needed.

Maybe the user could make it explicit. Like State.set_stateful() inside the Loop ctx before the first read (get_name). Then it would result in "prev:...". And we assert an assign sometime later. Or if this is not set, we would error on any assign inside the loop. This would allow that the layer dict can be created right away, and would probably be simpler.

Resolved This can maybe also be done automatically on all State members (attribs) of a Module when the module is called inside the loop. Because we assume that a State member of a module would always be assigned within the module call (otherwise it doesn't really make sense that the module defines this as a State).

There is still the question on behavior when there are multiple assigns in the loop. The last assign defines the value for the first read. But how to know which is the last before executing all? Or again let the user make this explicit? Is this always doable? Or maybe also only allowing a single assign inside the loop. And assert that it comes after the read.

However, what if the module is called multiple times? This would result in multiple assign calls. Actually the full set_stateful, read and assign call sequence, maybe within the same ctx. Disallow this? What would be the work-around for the user? Or must we allow this? But then this means that we can only construct the layer dict of the loop and everything in it after the loop construction is finished. Then we maybe also do not need set_stateful. Then instead of get_name, we must use some different mechanism because some explicit read/get call is needed to store the context.

Resolved Is this State only about Loop? When set_stateful is called inside a Cond, what would that mean? This must be allowed, in any ctx. But is the usage pattern always the same? It would be (in the same ctx) set_stateful, read, assign? We require read before assign, and assign to be in the same ctx as set_stateful. We allow read at any other place as well. We also allow assign before set_stateful in an outer ctx, which would define/overwrite the initial state. Yes, State is only about Loop (even though it also might require further logic or Cond etc), and specifically to handle the "prev:..." logic of RETURNN. For other purpose (e.g. returning something from Cond), there would be other separate ways.

albertz commented 3 years ago

Ok, posting as a new cleaned up comment on State, summarizing from the previous discussion/thoughts.

State is only about Loop (even though it also might require further logic or Cond etc), and specifically to handle the "prev:..." logic of RETURNN. For other purpose (e.g. returning something from Cond, see #24), there would be other separate ways. Contexts are hierarchical, and ordered. So far we have Loop (this proposal here), Cond (#24) and MaskedComputation (#23). Maybe more later? In TF, as far as I know, there is only tf.while_loop and tf.cond.

State would not derive from LayerRef but the ILayerMaker would know how to handle State (so the user does not need to call State.get explicitly).

State methods:

assign and read are allowed in all contexts (Loop, nested Loop, Cond, etc). However, not all possible sequences of calls (mixing different contexts) are allowed.

Behavior of State.assign. This currently depends on the context.

State.assign overwrites the previous assign when done in the same context and there was no read (get) in between. Although maybe we don't really need to care about overwrites?

There is an implicit first assign via the initial value. And also this assign is in the root context. So assign is always before read. I.e. for some read, there is always a previous assign, in the same or upper context.

First read in Loop ctx will correspond to the last assign in the ctx.

The first read (or get_name) inside a Loop ctx would result in "prev:..." iff there is an assign later in the Loop ctx. This implies that the layer dict cannot be created right away, because we have to wait until the Loop ctx is finished to know whether there is an assign later.

It must be possible to call a module multiple times. This would result in multiple assign calls. Actually the full (set_stateful), read and assign call sequence, maybe within the same ctx. This means that we can only construct the layer dict of the loop and everything in it after the loop construction is finished. This implies that we do not need set_stateful. Then instead of get_name, we must use some different mechanism, because some explicit read/get call is needed to store the context.
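A tiny runnable sketch (hypothetical bookkeeping, not the real implementation) of deferring this decision: record the read/assign events while tracing the loop body, and only afterwards decide whether a var's first read translates to "prev:...".

```python
# Hypothetical bookkeeping sketch: record read/assign events per state var
# while tracing the loop body, then decide afterwards whether the first read
# must become "prev:..." (i.e. whether any assign followed it).
events = []  # (op, var_name) in call order

def read(var):
    events.append(("read", var))

def assign(var):
    events.append(("assign", var))

# example traced call sequence inside a loop body:
read("h")
assign("h")
read("c")  # no assign to "c" follows, so "c" is not recurrent here

def first_read_is_prev(var):
    ops = [op for op, v in events if v == var]
    if "read" not in ops:
        return False
    return "assign" in ops[ops.index("read"):]
```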

Atticus1806 commented 3 years ago

I am just now reading into the updates from the last few days, so maybe I missed it: Is there a concept for setting a target of the rec unit yet?

Atticus1806 commented 3 years ago

Now looking at the LSTM example, if I understand correctly, self.c.prev() would result in referencing c. So pretty much what we have is a variable, most likely called the same as the layer itself. Side question: are there any cases where this same naming is unlikely/harmful? Otherwise, it kind of feels like we introduce a new variable self.c = State({batch,dim}) and two pieces of code, self.c.prev() and self.c.assign(c), for a statement which, from reading the code, semantically just means: we want to access the previous state of this layer, and would in the end just write prev:c in the config. Of course I realize that otherwise accessing .prev() before the layer itself might be difficult. Right now I cannot think of a simpler, more direct way, but I wanted to bring this up for discussion again, since this feels a bit much to me, and being as straightforward as possible is one of the core goals of the project.

In general: do we really need this complicated assign logic? Shouldn't it just have two values: the initial one and then the corresponding layer? Because I don't think that the prev value (for writing the config) should change once it is set for one layer. Or am I misunderstanding here?

albertz commented 3 years ago

Is there a concept for setting a target of the rec unit yet?

I'm not sure I understand. What do you mean or intend? Define a loss inside the rec layer? Why is there anything special about it? This would follow just the same logic as for other losses (see e.g. #9). You would define CE loss based on data:classes. It roughly would look like this:

x = ...
targets = get_extern_data("classes")
with Loop() as loop:
  ...
  prob = ...
  if train:
    targets_ = loop.unstack(targets)
    loss = ce(targets_, prob)
    loss.mark_as_loss()

You would not use the target option of RETURNN layers.

(Maybe this example should also be moved to the initial post above.)

albertz commented 3 years ago

Now looking at the LSTM example, if I understand correctly self.c.prev() would result in referencing c. So pretty much what we have is a variable, most likely called the same as the layer itself.

Yes, a local variable in the Python sense.

Note that prev was changed to get or read in some of the later discussion. I just forgot to update the initial example. I think get makes this even more clear that this behaves logically like a variable. The logic with "prev:..." in RETURNN is hidden here and would be applied automatically when needed (which is only when this is used inside a Rec loop, for the first get).

About the naming, I'm not exactly sure. Maybe. Probably. But not so important. I'm making #25 a priority here.

Side question: are there any cases where this same naming is unlikely/harmful? Otherwise, it kind of feels like we introduce a new variable self.c = State({batch,dim}) and two pieces of code, self.c.prev() and self.c.assign(c), for a statement which, from reading the code, semantically just means: we want to access the previous state of this layer, and would in the end just write prev:c in the config. Of course I realize that otherwise accessing .prev() before the layer itself might be difficult. Right now I cannot think of a simpler, more direct way, but I wanted to bring this up for discussion again, since this feels a bit much to me, and being as straightforward as possible is one of the core goals of the project.

I don't quite understand what you mean. Are you arguing about the naming? Or what do you think is too complicated? To me this looks very simple and straight-forward now. Compare the very first example in the initial post. The standard Python code and our rec loop code look conceptually very similar, and both just equally simple.

In general: do we really need this complicated assign logic? Shouldn't it just have two values: Initial and then the corresponding layer? Because I dont think that prev value (for writing the config) should change once it is set for one layer. Or am I missunderstanding here?

As said, it behaves logically like a Python variable. You can assign a Python variable multiple times. x = 1; ...; x = 2; ... is valid code.

You could say, ok we disallow multiple assigns to keep it simple. Only one single assign.

But then when the user writes e.g. this in the config:

lstm = Lstm(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

This would not work because the second call to lstm would assign the same state var a second time. Or you somehow introduce a separate/independent state var for each lstm call here. But then it would also not behave as you would naturally expect. So, to make behavior correct and as expected, you need to allow multiple assign calls.

This is the goal. To have the behavior straight-forward and as expectable, with a simple mental model.
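A plain-Python analogy of this sharing behavior (illustrative only, not the actual API): when a module keeps its hidden state as an attribute, a second call continues from the state left by the first.

```python
# Illustrative analogy: hidden state stored as a module attribute means two
# calls share it, so the second call continues where the first left off.
class Accum:
    def __init__(self):
        self.h = 0  # hidden state stored on the module

    def __call__(self, x):
        self.h = self.h + x  # the call updates the shared state
        return self.h

acc = Accum()
y1 = acc(1)   # uses and updates acc.h
y2 = acc(10)  # continues from the state left by the first call
```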

albertz commented 3 years ago

Note that I'm mostly working now on the generalized self attention on RETURNN side (https://github.com/rwth-i6/returnn/issues/391) and related (required) things like extended dimension tag support (e.g. https://github.com/rwth-i6/returnn/pull/577).

I want to have this generalized self attention in RETURNN ready to see if this also fits nicely into this rec loop concept or if we should adapt or change sth. The generalized self attention in RETURNN might introduce some new tricky things like a special dim tag for inside the loop (rec-history dim tag).

The extended dimension tag support in general in RETURNN will also become relevant here because in returnn-common, I want to put much more emphasis on this, and maybe even disallow (or deprecate) any other way. See #17.

Atticus1806 commented 3 years ago

What do you mean or intend?

When not setting a target for the Rec unit in the current config, this causes an error with the end layer during training: batch shapes (None,) and (None, 1) do not match.

I don't quite understand what you mean. Are you arguing about the naming? Or what do you think is too complicated?

I was thinking each prev: usage in the RETURNN config introduces at least 3 new lines of code, which in total would sum up quite quickly. But now I understand that you want this State to be more abstract and more powerful than just tracking the prev case; in that case these 3 lines are not much, I agree with you. Most likely my problem right now is that I don't really see how this would translate to a RETURNN config. I can see that you try to do more with State here, but right now I am not sure what for, which, if I understand you correctly, is the current goal. Can you maybe give an example where we would want to do more with the State?

This would not work because the second call to lstm would assign the same state var a second time. Or you somehow introduce a separate/independent state var for each lstm call here. But then it would also not behave as you would naturally expect. So, to make behavior correct and as expected, you need to allow multiple assign calls.

I am not sure whether I understand this example correctly. Right here we don't have a state variable, so why would we assign it twice if we only call the assign command once? Or do you mean something like (assuming the state var is self.state) where we do self.state.assign(layer1) and then self.state.assign(layer2)? In that case I would argue that reusing the same state for two different layers might also make things complicated. Most likely I am missing something here.

albertz commented 3 years ago

What do you mean or intend?

When not setting target for the Rec unit in the current config this causes an error with the end layer during training, that batch shapes (None, ) and (None, 1) do not match.

I don't really understand. It sounds like a bug. But anyway, things like this are really irrelevant for the discussion here. We should just fix or extend RETURNN in any way needed.

I was thinking each prev: later in the RETURNN config introduces at least 3 new lines of code,

Exactly like in the pure Python code in the beginning. Which is extremely simple?

which in total would sum up quite quickly.

How? The user probably rarely would write such code. Most States would be introduced by some existing Modules and the user would just use those Modules.

Most likely my problem right now is, I dont really see how this would translate to a RETURNN config.

Well, speaking purely abstractly: RETURNN is generic enough to allow just any possible construction, or if not, it should be, and should be extended to allow that.

So this should never be a concern here in this repo (how it translates to a RETURNN config). The main goal is to have it logical, simple, straight-forward, expectable. How it translates to RETURNN is an implementation detail, after we have clarified what design we want.

Speaking more specifically now, how we actually do the implementation: I think most aspects were already discussed here (the thread gets too long to easily see that...). There might be some open questions but I think they can be solved easily.

I will start working on this once I finished the generalized self attention logic in RETURNN.

I can see that you try to do more with State here, but right now I am not sure what for, which if I understand you correctly is the current goal .Can you maybe give an example where we would want to do more with the State?

I just gave you one? This is one such example:

lstm = Lstm(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

I am not sure whether I understand this example correctly. Right here we don't have a state variable, so why would we assign it twice if we only call the assign command once?

We have. The Lstm introduces it. See the Lstm definition above. E.g. there is then lstm.c. And the lstm(...) call results in lstm.c.assign.

This has multiple state assign calls because you call the same module multiple times which reuses the same hidden state.

Although, maybe this is actually not such a good example. I wonder now. I'm thinking about e.g. Universal Transformer now, where you also would call the same self attention layer multiple times. But there you do not want to have it use the same hidden state, but every call (layer) should still have its own hidden state, just the parameters should be shared.

So maybe the idea that a module can have hidden state is not good after all? Again, speaking of PyTorch, PyTorch modules can have that as well (as buffers), but usually it is never done this way. E.g. the LSTM module in PyTorch or LSTMCell module in PyTorch explicitly gets the prev state and cell as arguments and returns the new state and cell.

Maybe we should somehow make it more explicit, what hidden state is being used? Maybe like this:

lstm = Lstm(...)
with Loop() as loop:
  with StateScope():
    layer1 = lstm(loop.unstack(x))
  with StateScope():
    layer2 = lstm(layer1)
  output = loop.stack(layer2)

Maybe we also should differentiate between buffers and state? Buffers (in general, and also in PyTorch) are intended to not really have influence on the behavior. Or not on the model behavior at least. I'm actually not so sure about the typical usage of buffers in PyTorch. But they are definitely not used as hidden state. Hidden state is always passed explicitly.

Atticus1806 commented 3 years ago

See the Lstm definition above.

Okay, no clue how I missed that this references the definition of LSTM you gave here; I handled it like a black box... now things make a lot more sense. That's, I think, what caused most of the confusion.

So this should never be a concern here in this repo (how it translates to a RETURNN config). The main goal is to have it logical, simple, straight-forward, expectable. How it translates to RETURNN is an implementation detail, after we have clarified what design we want.

I see, then I think I do get it now. So (to verify I am not mistaken) the state is pretty much a concept of returnn-common which then in some cases might translate to prev: in RETURNN, but is a lot more powerful and could do a lot more than just that. prev: is only one of the possible cases State can be used for. Using this definition, I agree with you: the code for this is simple and straightforward.

Maybe we should somehow make it more explicit, what hidden state is being used?

I don't know. If we really want to end up with a module-like structure, I feel like hidden states should be something the user should not really need to deal with himself if he is not changing the basic configuration of that module (which shouldn't happen too often, because users usually should put their model together from modules without needing to make too many new ones, imo). Adding something like with StateScope() would, if not needed for some logic/expressiveness, make it more confusing I think.

albertz commented 3 years ago

So (to verify I am not mistaken) the state is pretty much a concept of returnn-common which then in some cases might translate to prev: in RETURNN, but is a lot more powerful and could do a lot more than just that. prev: is only one of the possible cases State can be used for. Using this definition, I agree with you: the code for this is simple and straightforward.

Yes. Actually, in most of the common cases here in returnn-common it would result in "prev:...". I'm thinking more about the maybe somewhat unusual cases.

Maybe we should somehow make it more explicit, what hidden state is being used?

I don't know. If we really want to end up with a module-like structure, I feel like hidden states should be something the user should not really need to deal with himself if he is not changing the basic configuration of that module (which shouldn't happen too often, because users usually should put their model together from modules without needing to make too many new ones, imo). Adding something like with StateScope() would, if not needed for some logic/expressiveness, make it more confusing I think.

There are multiple aspects/concepts:

So the model parameters are shared in this example:

lstm = Lstm(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

In the current proposal, also the hidden state is shared in this example.

However, maybe the more common use case is that the user wants to share the parameters but not the hidden state. Again, I'm thinking about the Universal Transformer. But also in most other examples I can think of where parameters are to be shared, you would not want to share the hidden state.

I think it should be simple to write code for the common cases, while also allowing for exotic things, while also being straight-forward/logical.

How would you actually implement this common case, where you want to share parameters but not the hidden state? Maybe like this:

lstm = Lstm(...)
with Loop() as loop:
  # Introduce separate state vars such that each layer has own hidden states.
  # Copy state var now to have the right initial state.
  h1, c1 = lstm.h.copy(), lstm.c.copy()
  h2, c2 = lstm.h.copy(), lstm.c.copy()
  lstm.h, lstm.c = h1, c1  # assign layer1's own hidden state vars
  layer1 = lstm(loop.unstack(x))
  lstm.h, lstm.c = h2, c2  # assign layer2's own hidden state vars
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

I'm not sure if this is still so easy. Or I would rather say no.

Note, in PyTorch, this (sharing parameters but not hidden state) would look like:

lstm = LstmCell(...)
layer1_state = lstm.initial_state()
layer2_state = lstm.initial_state()
out = []
for x_ in x:
  layer1, layer1_state = lstm(x_, layer1_state)
  layer2, layer2_state = lstm(layer1, layer2_state)
  out.append(layer2)
output = stack(out)

In PyTorch, sharing parameters and hidden state would look like:

lstm = LstmCell(...)
lstm_state = lstm.initial_state()
out = []
for x_ in x:
  layer1, lstm_state = lstm(x_, lstm_state)
  layer2, lstm_state = lstm(layer1, lstm_state)
  out.append(layer2)
output = stack(out)

In both cases, you make the handling of state explicit. So there is no confusion on the behavior of state, because it is always explicit.

So, I'm wondering if we also should make it always explicit. But still in a generic way. That is why I proposed StateScope. Maybe StateScope should be passed to the module call via state=StateScope() instead of using with here. Or maybe call it StateHolder or so, which can hold any nested structure of state vars. The idea with with was that it would be nested and automatically used for any nested sub module calls. I agree that the initial example code for StateScope above is not so clear. This should be worked out.

This is all about parameter sharing, i.e. calling the same module multiple times, and how the hidden state should be handled in this case. Because we hide the hidden state away in the current proposal, such a module call has side effects on internal state. Generally speaking, side effects are bad, because they make it more difficult to reason about the behavior of code. Usually you want that a call like lstm(...) does not have side effects.

Btw, this argument is also about module calls in general. Not really about rec loop actually. E.g. this code in the current proposal:

lstm = LstmOnSeq(...)
layer1 = lstm(x)
layer2 = lstm(layer1)
output = layer2

It again would share the parameters (as expected) but also the hidden state (not sure if that is expected, although it follows logically from the current proposal).

In PyTorch, the example of sharing the parameters but not the hidden state would look like:

lstm = Lstm(...)
layer1, _ = lstm(x)  # default initial hidden state
layer2, _ = lstm(layer1)  # default initial hidden state
output = layer2
albertz commented 3 years ago

Ok, following from those thoughts, I'm now thinking that we definitely should make it explicit. For all modules which have state (by themselves, or via sub modules). And state is not an attrib of the module but just an explicit parameter and return value of a module call.

There can be reasonable default initial state. E.g. when you have def forward(self, ..., state=None) in a module, you can check if state is None: state = some_default_initial_state() or so. Every such module would return the new state. E.g. like a tuple as in the PyTorch LSTM example.

The example for sharing parameters but not sharing hidden state would look sth like:

lstm = Lstm(...)
layer1_state = StateScope()
layer2_state = StateScope()
with Loop() as loop:
  layer1, layer1_state = lstm(loop.unstack(x), layer1_state)
  layer2, layer2_state = lstm(layer1, layer2_state)
  output = loop.stack(layer2)

Which is very similar to the PyTorch example.

From the point of view of the model implementation, it is a bit strange now that there is a conceptual difference between arguments which are state (and thus not normal LayerRef instances but State or StateScope). And maybe sometimes you want one argument to become a state, and maybe at other times a different argument, and maybe sometimes you do not even want any of the arguments to be states, e.g. calling lstm(x, (prev_h, prev_c)) where prev_h and prev_c are just other normal LayerRefs.

So this is a problem. I think the model definition should handle it all the same and just expect always LayerRefs (or nested structures of LayerRefs).

But then how to we handle this? Maybe more explicitly:

lstm = Lstm(...)
layer1_state = State(lstm.initial_state())
layer2_state = State(lstm.initial_state())
with Loop() as loop:
  layer1, layer1_state_ = lstm(loop.unstack(x), layer1_state.get())
  layer1_state.assign(layer1_state_)
  layer2, layer2_state_ = lstm(layer1, layer2_state.get())
  layer2_state.assign(layer2_state_)
  output = loop.stack(layer2)

State here would be a bit extended that it can also handle nested structures. lstm.initial_state would still return a normal LayerRef.

I think this would be very logical and straight-forward now again. And all is explicit and behavior is always clean.

Other modules (e.g. Lstm etc) would never deal with the State concept. The only point where you deal with that is when you write a rec loop, i.e. Loop(). Which is probably also encapsulated away in some module, like Decoder, and the user in most common cases probably just writes sth like:

encoder = Encoder(...)
decoder = Decoder(...)
output = decoder(encoder(...))

I wonder a bit if this is too complicated now, to write the Loop and explicitly handle the state this way. But I'm not really sure how to make this simpler.

Or maybe like:

lstm = Lstm(...)
loop_state = State()
loop_state.layer1 = lstm.initial_state()
loop_state.layer2 = lstm.initial_state()
with Loop() as loop:
  layer1, loop_state.layer1 = lstm(loop.unstack(x), loop_state.layer1)
  layer2, loop_state.layer2 = lstm(layer1, loop_state.layer2)
  output = loop.stack(layer2)

This would add a bit of clever handling into the State object. Basically the assign and get calls here are covered by __setattr__ and __getattr__. This allows to write the loop code a bit shorter and maybe more readable.
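The clever handling could look roughly like this runnable sketch (hypothetical, not the actual implementation): attribute writes cover the assign calls and attribute reads cover the get calls.

```python
# Hypothetical sketch of a state holder where attribute access covers the
# get/assign calls, as suggested above.
class StateHolder:
    def __init__(self):
        object.__setattr__(self, "_vars", {})  # bypass our own __setattr__

    def __setattr__(self, name, value):
        self._vars[name] = value  # would cover State.assign

    def __getattr__(self, name):
        # only called when normal attribute lookup fails, i.e. for the
        # state vars stored in _vars; would cover State.get
        return self._vars[name]

loop_state = StateHolder()
loop_state.layer1 = 0  # initial state
for x in [1, 2, 3]:
    loop_state.layer1 = loop_state.layer1 + x
# loop_state.layer1 accumulated 0 + 1 + 2 + 3
```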

albertz commented 3 years ago

I was quite busy with all the recent dim tag work (https://github.com/rwth-i6/returnn/pull/577) and generalized self attention (https://github.com/rwth-i6/returnn/issues/391). Generalized self attention is finished now, and the concept of consistent dim tags has been improved and extended a lot. Although there are still on-going discussions on taking this even further (https://github.com/rwth-i6/returnn/issues/597, https://github.com/rwth-i6/returnn/issues/632). Some of these might also be relevant for returnn-common but we should discuss this independently, maybe in #17.

I lost a bit track on the current state here, and the reasoning.

I think the final conclusion on the hidden state was to have it always explicit as a (state) argument to the module call, and also explicitly return the new state. So very much like PyTorch. So it is not really hidden at all.

I like it that way because there is no ambiguity in the code, how the hidden state is handled. This is all explicit.

The initial post description is maybe outdated on this. Do we still need to have the state as an attribute to the module (e.g. self.h = State({batch,dim}))? Why?

And why do we need this special State object with assign and get calls? Or the simplified code which uses __setattr__ and __getattr__ on some special State object?

I can see that the state module call argument might not just be a single LayerRef but it could also be some nested structure, esp if this is a module with other sub modules which also could have state. But this does not really explain why we need State. Maybe the concept of the State object was just introduced for the initial idea where we did not want to make this explicit, where the hidden state would have been all implicit and hidden? And now not needed anymore?

Edit The last comment just before this actually says:

state is not an attrib of the module but just an explicit parameter and return value of a module call.

So as I argued above, we would not have this module attrib anymore (like self.h = State({batch,dim})). States are always explicitly passed to a module call, and new states are returned from it.

However, I don't understand this anymore:

From the point of view of the model implementation, it is a bit strange now that there is a conceptual difference between arguments which are state (and thus not normal LayerRef instances but State or StateScope)

Why is there a conceptual difference? Does it need to be? Why? Why can't the state module call argument just be a regular LayerRef (or nested structure)?

Actually I also addressed this before:

I think the model definition should handle it all the same and just expect always LayerRefs (or nested structures of LayerRefs).

But still in this comment I keep State (or StateScope) as a special concept (e.g. loop_state = State()). Why is this needed?

Edit

The first example from above would look like:

lstm = Lstm(...)
loop_state_layer = lstm.initial_state()
with Loop() as loop:
  layer1, loop_state_layer = lstm(loop.unstack(x), loop_state_layer)
  layer2, loop_state_layer = lstm(layer1, loop_state_layer)
  output = loop.stack(layer2)

The second example from above would look like:

lstm = Lstm(...)
loop_state_layer1 = lstm.initial_state()
loop_state_layer2 = lstm.initial_state()
with Loop() as loop:
  layer1, loop_state_layer1 = lstm(loop.unstack(x), loop_state_layer1)
  layer2, loop_state_layer2 = lstm(layer1, loop_state_layer2)
  output = loop.stack(layer2)

The problem is that this does not exactly correspond to the Python while ...: loop as we want it to w.r.t. the Python local variables. We cannot correctly infer from this Python code that loop_state_layer is a recurrent variable which changes inside the loop. So this is probably one reason for the StateScope as it was suggested.

But the same problem actually applies to any other output: anything in the loop which wants to use the value from the previous iteration. This was not really addressed before? Is this not a problem? Or was this solved differently?

Edit Ah, this is actually also in the very first proposal. That is what State is really for: to handle the RETURNN prev: logic in a generic way. So basically this solves #6.

So, to recap:

I still see a problem in catching errors here. When the user writes the code ignoring State, just like before, it would compile without error, but it would do the wrong thing, i.e. not the expected behavior. In this example, it would always use lstm.initial_state() as the state in every iteration.

Can we somehow catch such errors to avoid unexpected behavior?

Or is this maybe not too much of a problem, as this is actually not that unexpected?

Also, do we want to allow skipping the state argument, i.e. to have a reasonable default? Modules might have state=None in the function signature and then internally do sth like if state is None: state = self.initial_state(). However, this code would have exactly the problem just described, i.e. it would not use the prev state but always the initial state in every iteration. Is this a fundamental problem which cannot really be solved?
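To illustrate this pitfall in plain Python (a hypothetical sketch, not the returnn-common API; `cell` and its integer "state" are stand-ins for a module call and its hidden state):

```python
# Hypothetical sketch of why a state=None default is dangerous inside a loop.
def cell(x, state=None):
  if state is None:
    state = 0  # stand-in for self.initial_state()
  new_state = state + x
  return new_state, new_state  # (output, new state)

# Explicit state threading: the state accumulates across iterations.
state = None
outs = []
for x in [1, 2, 3]:
  out, state = cell(x, state)
  outs.append(out)
assert outs == [1, 3, 6]

# Relying on the default inside the loop: every call silently restarts
# from the initial state, so nothing accumulates.
outs_default = [cell(x)[0] for x in [1, 2, 3]]
assert outs_default == [1, 2, 3]
```

Both variants compile and run without error; only the first gives recurrent behavior, which is exactly the kind of silent mistake discussed above.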

In PyTorch, this is the same behavior though, right? In PyTorch, there is the distinction between LSTM (on a seq) and LSTMCell (on a single frame). LSTM does have this default initial state, but LSTMCell does not, as it does not make sense for that case. In RETURNN, we have both together, which maybe causes this confusion. But we do not need to wrap it exactly the same here in returnn-common. We could also have some LSTMCell. Or maybe use RnnCell or RecCell to be able to use other rec units from RETURNN as well (not just NativeLstm2). Or wrap them all separately. Or both. And for all of these, we would require the state to be explicit as an argument (no default), although such modules would still have some function initial_state.

albertz commented 3 years ago

I'm questioning now whether this explicit code is maybe too complicated for many of the common use cases. On the RETURNN side, the hidden state is hidden and not explicit. So when translating some old-style RETURNN model to this new way, it would take some extra effort (although it should be straightforward).

One alternative is to introduce State inside the module call (but not as a module attrib). So a LSTMCell or RecCell could be defined like:

class LstmCell(Module):
  def __init__(...): ...
  def forward(self, input, state=None):
    if state is None:
      state = StateScope(...)
    h, c = state.h, state.c
    ...
    h.assign(new_h)
    c.assign(...)
    return new_h

Then we can use e.g. this code:

lstm = LstmCell(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

This would do param sharing between both lstm calls. However, it would not share the same hidden state (as expected).

The StateScope would be attached to the Loop (via sth like Loop.get_current() -> Loop which we can implement via the with loop: logic).

The not-so-nice thing about this is that we now clearly differentiate between state and other input. So it becomes complicated/non-straightforward when the user also wants to pass some custom state (custom h or c) to lstm, or how to return the new cell state c. It would maybe look like this:

layer1_c = get_state_scope(layer1).c.get()
Atticus1806 commented 3 years ago

So I just read through everything and I will add my comments. Since it was quite a lot, maybe I also misunderstood something; then just correct me. So first let's start with the concept of a State: from what I remember, we started with adding a logic which is able to handle the prev: logic from Returnn and then started expanding on that, adding more "features" like hidden state handling to it.

One of my general questions would be: why do we even do explicit recurrent handling, like getting state updates and so on, instead of just "references" to these updates which would be passed to Returnn for the config? Isn't the actual handling of how the LSTM unit works something that is part of Returnn, and what we try to achieve in this repo is a nice way of writing it down? From what I understand, right now you are also looking to include additional concepts.

From this I would also conclude my reasoning for deciding between the two variants for the hidden state: I feel like the implicit handling is one of the biggest strengths of Returnn, even though of course it's a double-edged sword to work with. Not explicitly having to worry about certain details of a layer makes it (at least from my HiWi view) quite a bit easier to start and also to work with. What we should aim for, in my opinion, is an interface which is as simple as possible to "just start", but has the flexibility of allowing stronger configurations once the user is more used to it. So I feel like it would be fine to accept that basic configurations (where in this case you don't do modifications to the internal hidden state logic) are as easy as possible, but if you want to make use of some stronger concepts you would have to go a bit more in depth. The problem why this causes trouble in other cases is when the documentation is not good enough, making users who try to make the transition into more detail feel lost.

Now onto the specific example: I would prefer the second option in the general case. For the more advanced options I would then include an option to get the StateScope or certain elements of it, and also include the possibility to make modifications to it. Again, I think this is more of an advanced concept, and I am not sure how much it will be used. Maybe I am mistaken here. So what we could allow is doing something like scope = get_state_scope(layer1) and then do stuff like scope.c.get() like you suggested, but also maybe scope.c.set() or even set_state_scope(layer1, scope) to overwrite the full scope with the (modified) new scope. This would leave the whole construction flexible enough to handle these cases, in my opinion, without too much of a workaround.

But overall I think this is a point where we need to put some thought into it. Maybe we could work out 2 or 3 concrete ways (with 3-4 examples each) and then ask other users about it, because I feel like this is something where user feedback might be meaningful to make a decision. What do you think?

albertz commented 3 years ago

So first let's start with the concept of a State: from what I remember, we started with adding a logic which is able to handle the prev: logic from Returnn and then started expanding on that, adding more "features" like hidden state handling to it.

Basically. But not directly. All the discussion here should be seen as independent from RETURNN really. It is really more about what would be a straightforward design (for people who do not know about RETURNN). It should not be that we adopt some strange thing from RETURNN only because that is how it is now in RETURNN.

So, when we think about loops (for ...: or while ...:), we need some way to access values from the previous iteration. And the question is, how to design that.

The next question is, whether we want to allow hidden state, which can be hidden, and thus is a separate concept, or whether there should not be a separate concept for hidden state, and it would just be the same as other values from previous iteration.

The cleanliness and straightforwardness of the design is of highest priority here. How this maps to RETURNN in the end is only secondary. We can surely map whatever we came up with, as long as it is well defined. Or if not, we can simply extend RETURNN such that we can. Although for almost everything discussed here, I think that RETURNN already supports it, so no extension or modification on RETURNN side would be needed.

One of my general questions would be: why do we even do explicit recurrent handling, like getting state updates and so on, instead of just "references" to these updates which would be passed to Returnn for the config?

I don't exactly understand. What do you mean by references to the updates?

The argument of explicit vs hidden/implicit is simple: Because only that way, it is straightforward. Hidden/implicit is always problematic. Esp when you want to change it, or have some control over it, it becomes unnatural. As long as you do not want to touch or access the hidden state, it doesn't matter. But as soon as you do, it matters. And there always will be such cases.

Isn't the actual handling of how the LSTM unit works something that is part of Returnn, and what we try to achieve in this repo is a nice way of writing it down?

We are not changing that. Here we simply discuss how we design the handling of accessing previous values (values from the prev loop iteration), and hidden state, or whether hidden state should be handled differently or just the same as other previous values.

From what I understand right now you are also looking to include additional concepts.

No. No underlying concept is really new. It would still all map to what RETURNN does right now. Just the API is new. This is the whole point here of returnn-common. And for designing the API, we have the freedom to do it as we want. And I think we should try to prioritize cleanliness and straightforwardness.

From this I would also conclude my reasoning for deciding between the two variants for the hidden state: I feel like the implicit handling is one of the biggest strengths of Returnn, even though of course it's a double-edged sword to work with. Not explicitly having to worry about certain details of a layer makes it (at least from my HiWi view) quite a bit easier to start and also to work with.

Many people claim that PyTorch is easier because it is all explicit. Usually nothing is hidden away. When reading other people's code, you rarely would ask yourself what it would actually do, or whether this module has some hidden state, because it is all explicit.

Explicitness can result in slightly more code but it is usually still pretty simple and short, and it is easier to follow and reason about because you don't have to think about implicit behavior.

Implicit behavior is maybe fine for all the simple cases but once it gets more complex, it can make it really hard to reason about.

I spoke with some other people and they all strictly preferred the explicitness.

What we should aim for in my opinion is an interface which is as simple as possible to "just start", but has the flexibility of allowing stronger configurations once the user is more used to it.

Yes, simplicity and flexibility are both the main goals of RETURNN, and also here of returnn-common.

However, I think you argue exactly for the opposite as I did before.

What does simple mean? Simple does not necessarily mean short code. Simple is about writing code, reading code, and understanding code. It should never be ambiguous, otherwise it is not simple. It should be clear and straightforward. Straightforwardness makes it simple to write and understand. Clearness makes it simple to read.

What does flexibility mean? It does not just mean that more complex things are possible. More complex things are always possible. Flexibility also means that more complex things are straightforward to do. Otherwise it is not really flexible, if something is not straightforward or is unclear.

So I feel like it would be fine to accept that basic configurations (where in this case you don't do modifications to the internal hidden state logic) are as easy as possible, but if you want to make use of some stronger concepts you would have to go a bit more in depth. The problem why this causes trouble in other cases is when the documentation is not good enough, making users who try to make the transition into more detail feel lost.

You cannot really compensate for a complicated non-straightforward design by just having better documentation. Treating hidden state as something different from non-hidden state just makes it more complicated, and not straightforward. When you have worked with non-hidden state before, it is not clear or straightforward how to work with hidden state now, when this is a different thing or concept.

Now onto the specific example: I would prefer the second option in the general case.

I actually asked someone on what behavior he would expect from this code:

lstm = LstmCell(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

He expected that the two lstm calls would not only share the params but also the hidden state, which is exactly not what would happen. Or it depends on the implementation of LstmCell. So this is a perfect example of what I meant before: it is not easy to read or understand. The behavior of the hidden state is unclear and ambiguous. And it is not straightforward how to handle hidden state now.
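The ambiguity can be made concrete with a plain-Python sketch (hypothetical, not the actual LstmCell; the integer "state" and `weight` stand in for hidden state and shared parameters):

```python
# Two possible readings of calling the same cell object twice in a loop body.
class Cell:
  def __init__(self):
    self.weight = 2  # shared parameter in both readings

  def __call__(self, x, state):
    new_state = state + self.weight * x
    return new_state, new_state  # (output, new state)

cell = Cell()

# Reading 1: each call site keeps its own state (what the design intends).
s1, s2 = 0, 0
out1, s1 = cell(1, s1)
out2, s2 = cell(out1, s2)
assert out2 == 4

# Reading 2: both call sites share one state (what the reader expected).
s = 0
out1b, s = cell(1, s)
out2b, s = cell(out1b, s)
assert out2b == 6  # different result, hence the ambiguity
```

With implicit hidden state, nothing in the call syntax distinguishes the two readings; with explicit state arguments, the choice is visible at the call site.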

But overall I think this is a point where we need to put some thought into it. Maybe we could work out 2 or 3 concrete ways (with 3-4 examples each) and then ask other users about it, because I feel like this is something where user feedback might be meaningful to make a decision. What do you think?

Yea, I also thought about getting some more feedback. It's a good idea to prepare some examples.

In all cases, what I think is important:

So the different examples could be:

I agree, some of these examples are maybe a bit exotic. But that is my point. It should still be straightforward to do. Otherwise it is not really flexible. In PyTorch, all of these are very simple and straightforward. In current RETURNN (dict-style net def), while all are possible in principle, only the first three are simple, while the others are definitely not, esp not straightforward. I expect and argue that whenever you have it explicit, it becomes straightforward.

albertz commented 3 years ago

Some further thoughts on the handling of state in general in a loop (orthogonal on the discussion whether hidden state should be a separate concept or not):

While assign and get on such a State object are somewhat canonical, this leads to lots of boilerplate code, which makes it somewhat more complicated to write than the corresponding natural logic in normal Python (or PyTorch) code (in a normal for or while loop).

I'm thinking about the options to simplify that to more canonical simplified Python, while still also not doing too much magic, such that it is still clear what happens.

So, given these options, I tend to prefer StateHolder with __setattr__ and __getattr__.

albertz commented 3 years ago

How would the StateHolder with __setattr__ and __getattr__ look like? Here some possible variations:

The Loop object could already create that, as loop.state. It's then only available inside the with block, but this is maybe ok.
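A minimal plain-Python sketch of such a StateHolder (hypothetical, only to illustrate the __setattr__/__getattr__ mechanics with prev-value semantics; the real implementation would operate on layer refs, and `_step` stands in for the loop advancing one iteration):

```python
class StateHolder:
  """Attribute reads return the value from the previous iteration
  (defaulting to 0 here), attribute writes record the new value."""

  def __init__(self):
    object.__setattr__(self, "_prev", {})  # values from the last iteration
    object.__setattr__(self, "_cur", {})   # values written this iteration

  def __setattr__(self, name, value):
    self._cur[name] = value

  def __getattr__(self, name):
    return self._prev.get(name, 0)  # previous value, or initial value 0

  def _step(self):
    # end of one loop iteration: current values become the "prev" values
    self._prev.update(self._cur)
    self._cur.clear()

state = StateHolder()
for _ in [1, 2, 3]:  # stands in for the unrolled loop iterations
  state.i = state.i + 1  # reads prev i, writes new i
  state._step()
assert state.i == 3
```

This is what makes `loop.state.i = loop.state.i + 1` below behave like the plain Python counter loop: the read on the right-hand side resolves to the previous iteration's value (RETURNN's prev: logic).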

Should we allow usage without defining the initial value or the shape? Maybe. In that case, the code for 2 LSTM layers, sharing params but not sharing hidden state, can look like:

lstm = Lstm(...)
with Loop() as loop:
  layer1, loop.state.layer1 = lstm(loop.unstack(x), loop.state.layer1)
  layer2, loop.state.layer2 = lstm(layer1, loop.state.layer2)
  output = loop.stack(layer2)

Or for 2 LSTM layers, sharing params, sharing hidden state:

lstm = Lstm(...)
with Loop() as loop:
  layer1, loop.state.lstm = lstm(loop.unstack(x), loop.state.lstm)
  layer2, loop.state.lstm = lstm(layer1, loop.state.lstm)
  output = loop.stack(layer2)

Or consider this Python code:

i = 0
for x_ in x:
  i = i + 1

Equivalent code here:

with Loop() as loop:
  loop.unstack(x)
  loop.state.i = loop.state.i + 1

How to explicitly specify the initial value, and maybe other things like shape? Maybe this can just be extended, like so:

with Loop() as loop:
  loop.unstack(x)
  loop.state.i = State(shape=(), initial=0)
  loop.state.i = loop.state.i + 1

This would be maybe a bit counterintuitive, as loop.state.i assignments and reads would normally expect or return a LayerRef, and an assignment by State is handled specially. But other variants might also look a bit inconsistent, like loop.define_state("i", initial=0) or so. I'm not sure.
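The special-cased assignment could be sketched like this in plain Python (hypothetical, just to show the mechanics; both class names and the simplified single-iteration semantics are illustration only):

```python
class State:
  """Hypothetical explicit state declaration (shape unused in this sketch)."""
  def __init__(self, shape=(), initial=0):
    self.shape = shape
    self.initial = initial

class StateHolder:
  """Assigning a State object declares the state (the special case);
  any other assignment sets the new value."""

  def __init__(self):
    object.__setattr__(self, "_decls", {})   # name -> State declaration
    object.__setattr__(self, "_values", {})  # name -> current value

  def __setattr__(self, name, value):
    if isinstance(value, State):
      self._decls[name] = value  # declaration, does not overwrite a value
    else:
      self._values[name] = value

  def __getattr__(self, name):
    if name in self._values:
      return self._values[name]
    if name in self._decls:
      return self._decls[name].initial
    raise AttributeError(name)

loop_state = StateHolder()
loop_state.i = State(shape=(), initial=0)  # special-cased: declares i
loop_state.i = loop_state.i + 1            # reads initial 0, writes 1
assert loop_state.i == 1
```

The `isinstance(value, State)` branch is exactly the inconsistency mentioned above: the same assignment syntax means two different things depending on the value's type.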

JackTemaki commented 3 years ago

The "current" example at the top looks already quite understandable and straightforward, but I have some comments / questions:

albertz commented 3 years ago

The "current" example at the top looks already quite understandable and straightforward, but I have some comments / questions:

  • Either the unstack function or the loop object itself should be able to take some kind of information about what axes will be used for the loop. I think in the current example everything was based on the time axis.

Yes right. This is basically the discussion here: https://github.com/rwth-i6/returnn/issues/597, https://github.com/rwth-i6/returnn/issues/632 and related.

We still did not fully clarify whether we maybe should allow some defaults for cases where it is unique. Or basically we anyway need to do that for all existing layers to not break backward compatibility.

But anyway, this is somewhat orthogonal to the discussion here.

  • the loop.last should be able to get an n parameter so that you can get the n last states (just what the window layer inside a recurrent net does now, to implement e.g. causal convolution decoders)

No, I don't think so. It should follow the same principles as everything in RETURNN: it should be as simple as possible, and atomic. You can very easily get this functionality, e.g. by putting a causal WindowLayer and then getting the loop.last of that (that is anyway how loop.last with n would work internally).

  • ~it is not clear yet to me how masks can/should be used, i.e. if I want to update states only at a certain condition, like with a future Cond object, but with a mask tensor matching the "unstacked" axis.~ Sorry, this is Masked computation wrapper #23

Yes, this is #23, but actually, when we now have all hidden state explicit as well, i.e. no distinction anymore, this also becomes pretty straightforward even without any such wrapper. The only reason such a wrapper can be useful is to allow potential further automatic optimizations (as MaskedComputationLayer does right now).

albertz commented 3 years ago

So, a first version is implemented now.

See the test in test_rec_ff. It uses this code:

x = get_extern_data("data")
with Loop() as loop:
  x_ = loop.unstack(x, axis="T")
  loop.state.h = y_ = Linear(n_out=13)([x_, loop.state.h])
  y = loop.stack(y_)
return y

Which results in this net dict:

{'loop': {'class': 'rec',
          'from': [],
          'unit': {'h': {'class': 'copy', 'from': 'linear'},
                   'linear': {'class': 'linear',
                              'from': ['rec_unstack', 'prev:h'],
                              'n_out': 13},
                   'output': {'class': 'copy', 'from': 'linear'},
                   'rec_unstack': {'axis': 'T',
                                   'class': 'rec_unstack',
                                   'from': 'base:data:data'}}},
 'output': {'class': 'copy', 'from': 'loop/output'}}

(Sorry for the pprint formatting...) (Some of the layer names will probably change in some future version.)

albertz commented 3 years ago

So I'm closing this now, as we have the initial design implemented.

Please open separate issues if sth is broken, missing, or whatever.

albertz commented 3 years ago

Just for reference, also Loop.end has been implemented now.

albertz commented 3 years ago

What's still missing is the default interface for all wrapped RETURNN layers with hidden state, which should make the state more explicit, as discussed here. This is #31.