asappresearch / sru

Training RNNs as Fast as CNNs (https://arxiv.org/abs/1709.02755)
MIT License

About the k value in SRUCell: why can it be 4 when n_in != out_size? #48

Closed · kzjeef closed this issue 6 years ago

kzjeef commented 6 years ago

Hi @taolei87 ,

I have a question about the weight matrix dimensions. In the SRUCell code I found k = 4 if n_in != out_size else 3, but when I read the paper, there are only three weight matrices: W, Wf, and Wr.

I also found that n_in will not equal out_size when the layer index is 0, but I don't understand why k = 4. What is the extra weight beyond W, Wf, and Wr?

Below is the init code:

class SRUCell(nn.Module):
    def __init__(self, n_in, n_out, dropout=0, vari_dropout=0,
                 use_tanh=1, bidirectional=False):
        # ...
        out_size = n_out*2 if bidirectional else n_out
        k = 4 if n_in != out_size else 3
        self.size_per_dir = n_out*k
        self.weight = nn.Parameter(torch.Tensor(
            n_in,
            self.size_per_dir*2 if bidirectional else self.size_per_dir
        ))
        # ...
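To make the question concrete, here is a quick sketch of the shape arithmetic above (my own code, with hypothetical sizes, not the library's):

def weight_shape(n_in, n_out, bidirectional=False):
    # Mirrors the arithmetic in SRUCell.__init__ above.
    out_size = n_out * 2 if bidirectional else n_out
    k = 4 if n_in != out_size else 3  # the branch in question
    size_per_dir = n_out * k
    return (n_in, size_per_dir * 2 if bidirectional else size_per_dir)

print(weight_shape(128, 256))  # (128, 1024): n_in != out_size, so k = 4
print(weight_shape(256, 256))  # (256, 768):  n_in == out_size, so k = 3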

Below is where n_in is not equal to out_size:

class SRU(nn.Module):
    def __init__(self, input_size, hidden_size,
                 num_layers=2, dropout=0, vari_dropout=0,
                 use_tanh=1, bidirectional=False):
        # ...
        self.n_in = input_size
        self.n_out = hidden_size
        # ...
        self.out_size = hidden_size*2 if bidirectional else hidden_size

        for i in range(num_layers):
            l = SRUCell(n_in=self.n_in if i == 0 else self.out_size,
                        n_out=self.n_out,
                        dropout=dropout if i+1 != num_layers else 0,
                        vari_dropout=vari_dropout,
                        use_tanh=use_tanh,
                        bidirectional=bidirectional)
            self.rnn_lst.append(l)
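Tracing that loop with hypothetical sizes (unidirectional, input_size=128, hidden_size=256) shows that only layer 0 can hit the k = 4 branch:

# Mirrors the n_in selection in the loop above, hypothetical sizes.
input_size, hidden_size, num_layers = 128, 256, 2
out_size = hidden_size  # unidirectional
for i in range(num_layers):
    n_in = input_size if i == 0 else out_size
    k = 4 if n_in != out_size else 3
    print(f"layer {i}: n_in={n_in}, out_size={out_size}, k={k}")
# layer 0: n_in=128, out_size=256, k=4
# layer 1: n_in=256, out_size=256, k=3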

Thanks

taolei87 commented 6 years ago

Hi,

k is the number of matrices (and matrix multiplications) in one cell. Normally k = 3, as described in the paper. But when the input and output sizes don't match, the input is multiplied by an additional matrix W' in the highway connection to change its dimension.
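For reference, the highway connection in the paper is h_t = r_t ⊙ g(c_t) + (1 − r_t) ⊙ x_t; in the mismatched case the last term becomes (1 − r_t) ⊙ (x_t W'), where W' is that fourth matrix.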

See #12 and #16 for the details about k.

Similar to nn.LSTM and nn.GRU, there are two sub-RNN modules, one per direction, when bidirectional=True. The outputs of the two sub-RNNs are concatenated to form the final output, so the actual output dimension is n_out*2. With the highway connection, this becomes:

output_final = gate * concat(output_fwd, output_bwd) + (1 - gate) * x
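In code form, that line might look roughly like this (a sketch with illustrative names, not the actual CUDA kernel; W_proj stands in for the extra matrix from the k = 4 case):

import torch

def bidir_highway(output_fwd, output_bwd, gate, x, W_proj=None):
    # output_fwd, output_bwd: per-direction outputs, each (batch, n_out)
    # gate: highway gate, (batch, n_out * 2); x: layer input, (batch, n_in)
    out = torch.cat([output_fwd, output_bwd], dim=-1)  # (batch, n_out * 2)
    # Project x with W_proj (n_in, n_out * 2) only when n_in != n_out * 2.
    residual = x if W_proj is None else x @ W_proj
    return gate * out + (1.0 - gate) * residual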

kzjeef commented 6 years ago

Thanks for your clear answer!