wise-east opened 6 years ago
I wanted to open a PR for this change, but that requires signing a contributor agreement and I'm not sure about the details there, so I'll just post the changes here since they are small. In qrnn.py, add the following class:
```python
class BiDirQRNNLayer(nn.Module):
    def __init__(self, input_size, hidden_size=None, save_prev_x=False, zoneout=0,
                 window=1, output_gate=True, use_cuda=True):
        super(BiDirQRNNLayer, self).__init__()

        assert window in [1, 2], "This QRNN implementation currently only handles convolutional window of size 1 or size 2"

        self.window = window
        self.input_size = input_size
        self.hidden_size = hidden_size if hidden_size else input_size
        self.zoneout = zoneout
        self.save_prev_x = save_prev_x
        self.prevX = None
        self.output_gate = output_gate
        self.use_cuda = use_cuda
        # One QRNNLayer per direction; the backward layer is fed the time-reversed input.
        self.forward_qrnn = QRNNLayer(input_size, hidden_size=hidden_size,
                                      save_prev_x=save_prev_x, zoneout=zoneout,
                                      window=window, output_gate=output_gate,
                                      use_cuda=use_cuda)
        self.backward_qrnn = QRNNLayer(input_size, hidden_size=hidden_size,
                                       save_prev_x=save_prev_x, zoneout=zoneout,
                                       window=window, output_gate=output_gate,
                                       use_cuda=use_cuda)

    def forward(self, X, hidden=None):
        if hidden is not None:
            fwd, h_fwd = self.forward_qrnn(X, hidden=hidden)
            bwd, h_bwd = self.backward_qrnn(torch.flip(X, [0]), hidden=hidden)
        else:
            fwd, h_fwd = self.forward_qrnn(X)
            bwd, h_bwd = self.backward_qrnn(torch.flip(X, [0]))
        # Flip the backward outputs back into the original time order before concatenating.
        bwd = torch.flip(bwd, [0])
        return torch.cat([fwd, bwd], dim=-1), torch.cat([h_fwd, h_bwd], dim=-1)
```
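For reference, the flip-run-flip pattern the layer relies on can be sketched without torch, using a hypothetical `scan` function as a stand-in for a one-directional recurrent layer (not the actual QRNN recurrence):

```python
def scan(xs, h0=0.0):
    # Stand-in recurrence: h_t = 0.5 * h_{t-1} + x_t, returning all hidden states.
    # This is NOT the QRNN's fo-pooling, just a placeholder to show the wiring.
    hs, h = [], h0
    for x in xs:
        h = 0.5 * h + x
        hs.append(h)
    return hs

def bidir_scan(xs):
    # Forward pass over the sequence as-is.
    fwd = scan(xs)
    # Backward pass: reverse the input (torch.flip(X, [0]) in the snippet above),
    # run the same recurrence, then reverse the outputs back so that position t
    # of `bwd` summarizes the suffix xs[t:].
    bwd = scan(xs[::-1])[::-1]
    # Pair per time step (the analogue of torch.cat([fwd, bwd], dim=-1)).
    return list(zip(fwd, bwd))
```

The second reversal is the easy thing to forget: without it, the backward features at position t would describe the prefix rather than the suffix.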
In the same file, inside `class QRNN(torch.nn.Module):`, replace:

```python
self.layers = torch.nn.ModuleList(
    layers if layers else [QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, **kwargs)
                           for l in range(num_layers)])
```

with:

```python
if bidirectional:
    # Each bidirectional layer outputs forward and backward features concatenated,
    # so every layer after the first takes hidden_size * 2 inputs.
    self.layers = torch.nn.ModuleList(
        layers if layers else [BiDirQRNNLayer(input_size if l == 0 else hidden_size * 2, hidden_size, **kwargs)
                               for l in range(num_layers)])
else:
    self.layers = torch.nn.ModuleList(
        layers if layers else [QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, **kwargs)
                               for l in range(num_layers)])
```
Finally, remove the assert statement above that rejects `bidirectional`:

```python
assert bidirectional == False, 'Bidirectional QRNN is not yet supported'
```
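To see why the `hidden_size * 2` change matters, here is a small sketch (hypothetical helper, not part of the patch) of the per-layer input widths in a stacked QRNN:

```python
def layer_input_sizes(input_size, hidden_size, num_layers, bidirectional):
    # Layer 0 sees the raw input; later layers see the previous layer's output,
    # which is doubled when forward and backward features are concatenated.
    width = hidden_size * 2 if bidirectional else hidden_size
    return [input_size if l == 0 else width for l in range(num_layers)]
```

Without the `* 2`, the second layer's convolution would be built for `hidden_size` inputs but receive `hidden_size * 2` features and fail with a shape mismatch.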
I didn't test this very thoroughly, but it works for me on a basic use case. If you need to fix something, please post it here as well :) @salesforce feel free to incorporate this into your code; I don't think any paperwork is needed for that.
When will the bidirectional QRNN be ready? It says it would be available in the near future, but I guess I underestimated the span of time represented by 'near'. I'm wondering whether it is being developed at all.