facebookarchive / torch-rnnlib

This library provides utilities for creating and manipulating RNNs to model sequential data.

Attention Model #4

Open D-X-Y opened 7 years ago

D-X-Y commented 7 years ago

I want to use an attention RNN with this library.

Are there any examples or guides for building an attention RNN?

jgehring commented 7 years ago

There's no concrete example, but one way to do it would be to extend one of the cells (e.g. the LSTM cell) to run the hidden state through the attention model. For example, let's say you want to feed the attention model's output to the LSTM alongside its usual input. You could do something like this (not tested) and then use it in an nn.SequenceTable as described in the README.

```lua
require 'nn'
require 'nngraph'
local rnnlib = require 'rnnlib'

-- LSTM cell that feeds the attention model's output to the LSTM
-- alongside the regular input at every time step.
function LSTM_Att(nin, nhid, attn)
    -- "attn" is an nn module that produces 'nhid'-sized output
    local omake, oinit = rnnlib.cell.LSTM(nin + nhid, nhid)

    local make = function(prevch, input)
        -- Here, "input" should be a table containing the original input
        -- and the full attention model input at every time step
        local oinput, ain = input:split(2)
        local prevc, prevh = prevch:split(2)

        -- apply attention model on its input and the previous hidden state
        local aout = attn({ain, prevh})

        -- concatenate attention output to input along the feature dimension
        -- (assuming bsz x features tensors; nn.JoinTable needs an explicit
        -- dimension) and apply the LSTM computation as usual
        return omake(
            nn.Identity()({prevc, prevh}),
            nn.JoinTable(2)({oinput, aout})
        )
    end

    return make, oinit
end
```
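The cell above leaves `attn` unspecified. One possible choice is plain dot-product attention over a set of context vectors; the sketch below is only an illustration and is not part of rnnlib. It assumes that `ain` is a `bsz x T x nhid` tensor of context vectors (for example, encoder outputs) and that `prevh` is `bsz x nhid`; `DotAttention` is a hypothetical helper name. It scores each context vector against the previous hidden state, normalizes the scores with a softmax over the `T` positions, and returns the weighted sum, which is the `nhid`-sized output `LSTM_Att` expects.

```lua
require 'nn'
require 'nngraph'

-- Hypothetical dot-product attention module (not part of rnnlib).
-- Inputs: {ctx, h} with ctx: bsz x T x nhid and h: bsz x nhid
-- Output: bsz x nhid, a softmax-weighted sum of the T context vectors
local function DotAttention(nhid)
    local ctx = nn.Identity()()
    local h = nn.Identity()()

    -- reshape h to bsz x nhid x 1 so it can be batch-multiplied with ctx
    local hcol = nn.View(nhid, 1):setNumInputDims(1)(h)
    -- scores: (bsz x T x nhid) * (bsz x nhid x 1) -> bsz x T x 1
    local scores = nn.MM()({ctx, hcol})
    -- flatten to bsz x T and normalize over the T positions
    local weights = nn.SoftMax()(nn.View(-1):setNumInputDims(2)(scores))
    -- weighted sum: (bsz x 1 x T) * (bsz x T x nhid) -> bsz x 1 x nhid
    local wrow = nn.View(1, -1):setNumInputDims(1)(weights)
    local summary = nn.MM()({wrow, ctx})
    -- drop the singleton dimension -> bsz x nhid
    local out = nn.View(-1):setNumInputDims(2)(summary)

    return nn.gModule({ctx, h}, {out})
end
```

It could then be passed in as `LSTM_Att(nin, nhid, DotAttention(nhid))`. Dot-product scoring has no parameters of its own; a learned scoring function (e.g. an extra `nn.Linear` applied to `prevh` before the first `nn.MM`) is an equally reasonable variant.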

Similarly, you could apply the attention model after the LSTM cell computation or even change the cell computation itself.