Closed: bcserna closed this issue 6 years ago
As you can see in train.py, x is created with the following shape:
xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
So, for x, before the reshaping, the shape is (batch, n_sequence, n_tokens, seq_or_pos):
- batch will be used to access a specific element in a batch,
- n_sequence will be used to access a specific sequence of an instance; it is only useful for tasks that take multiple sentences as input (multiple choice problems, for example),
- n_ctx will be used to access a specific token of a sequence,
- seq_or_pos can either take 0 or 1 and is used to tell whether the slice holds token indices or position indices in the embedding matrix.

Again from train.py, we can see that we fill xmb[..., 0] with the token indices and xmb[..., 1] with the position indices:
xmb[i, 0, :l12, 0] = x12
xmb[i, 1, :l13, 0] = x13
[...]
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
Indeed, in this version of the transformer network, the positional embeddings are learned, so positions in the input sequence are treated just like normal tokens and have a corresponding row in the embedding matrix. You can see that the position embeddings are located at the end of the embedding table (starting at index n_vocab + n_special).
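Here is a minimal sketch of that encoding with toy sizes (the vocabulary size, special-token count, context length and token ids below are made up for illustration, not the real model's values):

import numpy as np

# toy sizes, chosen only for the example
n_vocab, n_special, n_ctx = 10, 3, 8
n_batch, n_sequence = 2, 2            # e.g. two choices per multiple-choice instance

xmb = np.zeros((n_batch, n_sequence, n_ctx, 2), dtype=np.int32)

x12 = np.array([1, 4, 2, 7, 5])       # hypothetical token ids for the first sequence
x13 = np.array([1, 4, 2, 9, 8, 6])    # hypothetical token ids for the second sequence
xmb[0, 0, :len(x12), 0] = x12         # channel 0: token indices
xmb[0, 1, :len(x13), 0] = x13

# channel 1: position indices, offset so they point at the rows of the
# embedding matrix that come after the vocabulary and the special tokens
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)

print(xmb[0, 0, :, 1])                # [13 14 15 16 17 18 19 20]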
So, if we analyse the forward method line by line, we get:
x = x.view(-1, x.size(-2), x.size(-1))
We first flatten (remove) the n_sequence dimension, as the inference on each sequence of an input is independent. We get a tensor of shape (n_batch * n_sequence, n_tokens, seq_or_pos).
e = self.embed(x)
We fetch the embeddings for the tokens AND the positions at the same time, so we get a tensor of shape (n_batch * n_sequence, n_tokens, seq_or_pos, dim_emb), with dim_emb being the dimension of the embedding vectors (here 768).
h = e.sum(dim=2)
Then, as described in the research paper, we simply add each token embedding to its corresponding position embedding. We get a tensor of shape (n_batch * n_sequence, n_tokens, dim_emb) that will be the input to our transformer blocks (Block).
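To make the shape bookkeeping concrete, here is a self-contained sketch of those three lines with a toy embedding table (the sizes here are assumptions for illustration, not the real model's vocabulary size or 768-dimensional embeddings; as described above, a single embedding matrix holds both the token rows and the position rows):

import torch
import torch.nn as nn

# toy sizes, chosen only for the example
n_vocab, n_special, n_ctx, dim_emb = 10, 3, 8, 16
n_batch, n_sequence = 2, 2

# one embedding table holding both the token rows and the position rows
embed = nn.Embedding(n_vocab + n_special + n_ctx, dim_emb)

# x[..., 0] holds token indices, x[..., 1] holds position indices (see the snippet above)
x = torch.zeros(n_batch, n_sequence, n_ctx, 2, dtype=torch.long)
x[..., 1] = torch.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)

x = x.view(-1, x.size(-2), x.size(-1))   # (n_batch * n_sequence, n_ctx, 2)          -> torch.Size([4, 8, 2])
e = embed(x)                             # (n_batch * n_sequence, n_ctx, 2, dim_emb) -> torch.Size([4, 8, 2, 16])
h = e.sum(dim=2)                         # (n_batch * n_sequence, n_ctx, dim_emb)    -> torch.Size([4, 8, 16])
print(x.shape, e.shape, h.shape)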
The rest of the function is just the application of the blocks to the input so I won't detail it.
Great explanation, it's clear now, thank you!
So there's the TransformerModel's forward method, and I just can't get a hold of the position embedding part (and might be wrong about others). So, as far as I can tell, step-by-step it goes like this:

[ ? x sequences (?) x tokens (512) ]
[ ? x sequences (?) x tokens (512) x emb_dim (768) ]
[ ? x sequences x emb_dim (768) ]
[ sequences x tokens (512) x emb_dim (768) ] here?

My questions are: what are the axes of the x, e, and h tensors? Thank you in advance!