Element-Research / rnn

Recurrent Neural Network library for Torch7's nn

Proper way to share weights for `nn.BiSequencer`? #169

Closed: willfrey closed this issue 8 years ago

willfrey commented 8 years ago

In Deep Speech 2, they share the input-hidden weights for both directions of a bidirectional RNN. It's mentioned in the middle of page 5.

What is the proper way to do this for various architectures such as nn.LSTM, nn.FastLSTM, nn.GRU, or any nn.Recurrence?

For any nn.Recurrence instance, I think that I can do this:

    require 'rnn'

    inputsize = 5 -- dummy value
    hiddensize = 6 -- dummy value

    -- input-to-hidden projection, to be shared between the two directions
    i2h = nn.Linear(inputsize, hiddensize)

    fwd_rm = nn.Sequential()
        :add(nn.ParallelTable() -- input is {x[t], h[t-1]}
            :add(i2h)
            :add(nn.Linear(hiddensize, hiddensize)))
        :add(nn.CAddTable()) -- merge
        :add(nn.Sigmoid()) -- transfer

    bwd_rm = nn.Sequential()
        :add(nn.ParallelTable() -- input is {x[t], h[t-1]}
            :add(i2h:sharedClone()) -- ties weight/bias/gradWeight/gradBias to i2h
            :add(nn.Linear(hiddensize, hiddensize)))
        :add(nn.CAddTable()) -- merge
        :add(nn.Sigmoid()) -- transfer

    fwd = nn.Recurrence(fwd_rm, hiddensize, 1)
    bwd = nn.Recurrence(bwd_rm, hiddensize, 1)

    brnn = nn.BiSequencer(fwd, bwd, nn.CAddTable())

But there is probably a more elegant way to do this, perhaps by initializing only the fwd RNN, making bwd = fwd:clone(); bwd:reset(), and then re-sharing the weights, something like the sketch below.
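
Here is a rough sketch of that idea, working on the recurrent module rather than the wrapped nn.Recurrence (the :get(1):get(1) path is just where the input-to-hidden Linear happens to sit in my Sequential above):

    -- clone the forward recurrent module, give the backward copy fresh
    -- parameters, then re-tie only the input-to-hidden Linear
    bwd_rm = fwd_rm:clone()
    bwd_rm:reset()
    bwd_rm:get(1):get(1):share(fwd_rm:get(1):get(1),
        'weight', 'bias', 'gradWeight', 'gradBias')

    fwd = nn.Recurrence(fwd_rm, hiddensize, 1)
    bwd = nn.Recurrence(bwd_rm, hiddensize, 1)
    brnn = nn.BiSequencer(fwd, bwd, nn.CAddTable())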

For nn.LSTM, nn.FastLSTM and nn.GRU, I don't have the faintest idea.

I'm still very much a Torch novice, so any help is appreciated!

Thanks.

willfrey commented 8 years ago

Is it as simple as narrowing down to the modules I want to share through a series of fwd:get() and bwd:get() calls, and then calling bwd:get(...):share(fwd:get(...), 'weight', 'bias', 'gradWeight', 'gradBias')?
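
Here is the minimal check I used to convince myself the parameters really end up tied (using the i2h and bwd_rm names from my first example; the :get() path is just where that Linear sits in my Sequential):

    -- after sharing, both directions should point at the same weight storage
    bwd_i2h = bwd_rm:get(1):get(1)
    assert(torch.pointer(i2h.weight:storage())
        == torch.pointer(bwd_i2h.weight:storage()))

    -- and an in-place update on one side should be visible on the other
    i2h.weight:fill(0.5)
    assert(bwd_i2h.weight[1][1] == 0.5)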

That appears to be working. My biggest concern is that I'll break something to do with the nn.AbstractRecurrent definition.
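
For nn.FastLSTM, if I'm reading its source correctly, the input-to-hidden projection for all four gates is a single nn.Linear kept in an i2g field (that field name is just my reading of the code, so worth double-checking), which would make the sharing look something like this:

    require 'rnn'

    -- sketch only: `i2g` is what I believe FastLSTM calls its
    -- input-to-gates Linear; verify against the FastLSTM source
    fwd_lstm = nn.FastLSTM(inputsize, hiddensize)
    bwd_lstm = nn.FastLSTM(inputsize, hiddensize)

    bwd_lstm.i2g:share(fwd_lstm.i2g,
        'weight', 'bias', 'gradWeight', 'gradBias')

    brnn_lstm = nn.BiSequencer(fwd_lstm, bwd_lstm, nn.CAddTable())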

nicholas-leonard commented 8 years ago

@willfrey Your first example above should work. What you describe in your second comment should also work.

willfrey commented 8 years ago

Thank you!