Yeah, we could do this. The paper is good about details.
Here's a reference implementation: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
Is that an official one?
Yes, it was merged by Lukasz Kaiser: https://github.com/tensorflow/tensor2tensor/pull/1
There are at least two GitHub repositories implementing the Transformer in PyTorch, eladhoffer/seq2seq.pytorch and jadore801120/attention-is-all-you-need-pytorch, that can be used as starting points; both are MIT licensed.
We are working with @jadore801120; hopefully we will have something soon. Don't assume it will magically solve all your problems, though.
This is now in; I will post details on how to activate it soon.
I tried to train on CPU, and it looks like there are some issues (I could be wrong). I left some comments on the commit details.
Currently, the training process is still running with the following arguments:
python train.py -data data/demo.train.pt -save_model demo-model -encoder_layer transformer -decoder_layer transformer -rnn_size 512 -word_vec_size 512
And here's my patch to make it work:
diff --git a/onmt/modules/StructuredAttention.py b/onmt/modules/StructuredAttention.py
index 49c56bb..8d897d5 100644
--- a/onmt/modules/StructuredAttention.py
+++ b/onmt/modules/StructuredAttention.py
@@ -17,8 +17,11 @@ class MatrixTree(nn.Module):
         laplacian = input.exp() + self.eps
         output = input.clone()
         for b in range(input.size(0)):
-            lap = laplacian[b].masked_fill(
-                Variable(torch.eye(input.size(1)).cuda().ne(0)), 0)
+            if torch.cuda.is_available():
+                var_input = torch.eye(input.size(1)).cuda().ne(0)
+            else:
+                var_input = torch.eye(input.size(1)).ne(0)
+            lap = laplacian[b].masked_fill(Variable(var_input), 0)
             lap = -lap + torch.diag(lap.sum(0))
             # store roots on diagonal
             lap[0] = input[b].diag().exp()
@@ -39,6 +42,8 @@

 if __name__ == "__main__":
     dtree = MatrixTree()
-    q = torch.rand(1, 5, 5).cuda()
+    q = torch.rand(1, 5, 5)
+    if torch.cuda.is_available():
+        q = q.cuda()
     marg = dtree.forward(Variable(q))
     print(marg.sum(1))
diff --git a/onmt/modules/Transformer.py b/onmt/modules/Transformer.py
index f7e189e..8fea4cf 100644
--- a/onmt/modules/Transformer.py
+++ b/onmt/modules/Transformer.py
@@ -3,6 +3,7 @@ Implementation of "Attention is All You Need"
 """
 import torch
+import torch.cuda
 import torch.nn as nn
 import numpy as np
 import onmt.modules
@@ -76,7 +77,9 @@ class TransformerDecoder(nn.Module):
                 d_inner,
                 opt.dropout)
         self.dropout = opt.dropout
-        self.mask = get_attn_subsequent_mask(5000).cuda()
+        self.mask = get_attn_subsequent_mask(5000)
+        if torch.cuda.is_available():
+            self.mask = self.mask.cuda()

     def forward(self, input, context, src_words, tgt_words):
         """
diff --git a/onmt/modules/Util.py b/onmt/modules/Util.py
index e4444c4..3d349f3 100644
--- a/onmt/modules/Util.py
+++ b/onmt/modules/Util.py
@@ -35,8 +35,8 @@ class LayerNorm(nn.Module):
     def forward(self, z):
         if z.size(1) == 1:
             return z
-        mu = torch.mean(z, dim=1).unsqueeze(1)
-        sigma = torch.std(z, dim=1).unsqueeze(1)
+        mu = torch.mean(z, dim=1)
+        sigma = torch.std(z, dim=1)
         ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
         ln_out = ln_out * self.a_2.expand_as(ln_out) \
             + self.b_2.expand_as(ln_out)
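For anyone hitting the same CPU-only problem, here is a minimal, self-contained sketch of the device-agnostic pattern the patch is going for (the maybe_cuda helper is purely illustrative, not part of OpenNMT-py). The key detail is that .cuda() returns a copy on the GPU rather than moving the tensor in place, so its result has to be assigned back:

import torch
from torch.autograd import Variable  # old-style API used elsewhere in this thread


def maybe_cuda(x):
    # Hypothetical helper: move a tensor to the GPU only when one is available.
    # .cuda() returns a new tensor, so callers must use the returned value.
    return x.cuda() if torch.cuda.is_available() else x


# Rebuild the diagonal mask from the StructuredAttention change above on
# whichever device is available, then zero out the diagonal of a score matrix.
n = 5
diag_mask = Variable(maybe_cuda(torch.eye(n).ne(0)))
scores = Variable(maybe_cuda(torch.rand(n, n)))
masked = scores.masked_fill(diag_mask, 0)
print(masked.sum())

The same reassignment applies to the decoder mask in Transformer.py: self.mask = self.mask.cuda() inside the availability check, rather than calling .cuda() and discarding the result.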
Oh, yes, thank you. I think I got all of these in my next PR. I will test on CPU.
Closing this issue for now. More work needs to be done to validate this, but the code is there.
Are there plans to add a PyTorch implementation of the Transformer from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762)?