linkedin / dagli

Framework for defining machine learning models, including feature generation and transformations, as directed acyclic graphs (DAGs).
BSD 2-Clause "Simplified" License

Stacking LSTMs #13

Closed: cyberbeat closed this issue 2 years ago

cyberbeat commented 2 years ago

This is the code for LSTM stacking:

public NNLSTMLayer stack(int... unitCounts) {
  Arguments.check(unitCounts.length > 0, "At least one unit count for at least one new layer must be provided");
  NNLSTMLayer previous = this.withUnitCount(unitCounts[0]).withInput(this.getInputLayer());

  for (int i = 1; i < unitCounts.length; i++) {
    previous = previous.withInput(previous).withUnitCount(unitCounts[i]);
  }

  return previous;
}

I wonder why the first stacked layer gets this.getInputLayer() as its input rather than this?

jeffpasternack commented 2 years ago

The idea is that stack(...) returns a stack/sequence of layers "like" this, but, as you point out, the returned stack doesn't actually include this. The nice thing about this design is that you can conveniently define the sizes of all the layers with a single passed array (otherwise you'd have to define the size of the first layer separately). I agree that the semantics aren't obvious from the function name alone, though; perhaps something like createStack(...) would be clearer?
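To make these semantics concrete, here is a minimal, self-contained sketch. ToyLayer, unitCountsToSource() and StackDemo are hypothetical names invented for illustration and are not part of Dagli; the only assumption carried over from the thread is that the with...() methods return modified copies rather than mutating the layer, and the body of stack(...) mirrors the method quoted above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an immutable layer; NOT Dagli's NNLSTMLayer.
final class ToyLayer {
  private final int unitCount;
  private final ToyLayer input; // null marks the start of the chain

  ToyLayer(int unitCount, ToyLayer input) {
    this.unitCount = unitCount;
    this.input = input;
  }

  // The "with" methods return modified copies and never mutate the receiver.
  ToyLayer withUnitCount(int newUnitCount) {
    return new ToyLayer(newUnitCount, input);
  }

  ToyLayer withInput(ToyLayer newInput) {
    return new ToyLayer(unitCount, newInput);
  }

  ToyLayer getInputLayer() {
    return input;
  }

  // Same control flow as the stack(...) method quoted in this thread.
  ToyLayer stack(int... unitCounts) {
    if (unitCounts.length == 0) {
      throw new IllegalArgumentException("At least one unit count must be provided");
    }
    // The first new layer is a resized copy of this layer that keeps this layer's input.
    ToyLayer previous = this.withUnitCount(unitCounts[0]).withInput(this.getInputLayer());
    for (int i = 1; i < unitCounts.length; i++) {
      // Each further layer is a resized copy that takes the previous layer as its input.
      previous = previous.withInput(previous).withUnitCount(unitCounts[i]);
    }
    return previous;
  }

  // Unit counts from this layer down to the start of the chain.
  List<Integer> unitCountsToSource() {
    List<Integer> counts = new ArrayList<>();
    for (ToyLayer layer = this; layer != null; layer = layer.input) {
      counts.add(layer.unitCount);
    }
    return counts;
  }
}

class StackDemo {
  public static void main(String[] args) {
    ToyLayer source = new ToyLayer(8, null);      // upstream layer feeding the stack
    ToyLayer template = new ToyLayer(99, source); // the layer that stack(...) is called on
    ToyLayer top = template.stack(64, 32);
    System.out.println(top.unitCountsToSource()); // prints [32, 64, 8]
  }
}
```

The template's own unit count (99) never appears in the result: the returned chain consists only of resized copies, and the bottom copy is attached directly to the template's input.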

cyberbeat commented 2 years ago

Oh, thanks. It would be good to clarify this in the Javadoc; the name is OK.