spro / practical-pytorch

Go to https://github.com/pytorch/tutorials - this repo is deprecated and no longer maintained
MIT License

Complexity of attention score calculation in seq2seq tutorial #107

Open: pravn opened this issue 6 years ago

pravn commented 6 years ago

I am wondering whether the attention score calculation could be vectorized over the batch instead of running two nested loops (batch size × encoder time steps).

The relevant section of code is:

```python
class Attn(nn.Module):
    # ....

    def forward(self, hidden, encoder_outputs):
        # For each batch of encoder outputs
        for b in range(this_batch_size):
            # Calculate energy for each encoder output
            for i in range(max_len):
                attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
```

Instead, I think we can eliminate the outer (batch) loop so that only the loop over encoder time steps remains: reshape hidden so that batch size is its first dimension, then compute the score for the whole batch at each time step.

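```python
# Assumes self.score accepts (B, N) tensors and returns one energy per batch element, shape (B,)
attn_energies = Variable(torch.zeros(max_len, this_batch_size))  # (S, B)
hidden = hidden.transpose(0, 1).squeeze(1)  # (1, B, N) -> (B, 1, N) -> (B, N)

for i in range(max_len):
    attn_energies[i] = self.score(hidden, encoder_outputs[i])  # encoder_outputs[i]: (B, N)
```
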
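A minimal runnable sketch of the single-loop idea, assuming dot-product scoring and the tutorial's layouts (hidden as (1, B, N), encoder_outputs as (S, B, N)). The helper batched_dot_score is hypothetical: it stands in for a self.score that takes (B, N) tensors and returns (B,). Variable is dropped because plain tensors track gradients in PyTorch 0.4+. The two-loop version is kept only to check that both give the same energies.

```python
import torch

def batched_dot_score(hidden, encoder_output):
    # Hypothetical batched score: row-wise dot product of two (B, N) tensors -> (B,)
    return (hidden * encoder_output).sum(dim=-1)

def energies_two_loops(hidden, encoder_outputs):
    # Original form: one score call per (batch, time step) pair
    max_len, batch_size, _ = encoder_outputs.shape
    attn_energies = torch.zeros(batch_size, max_len)
    for b in range(batch_size):
        for i in range(max_len):
            attn_energies[b, i] = hidden[0, b].dot(encoder_outputs[i, b])
    return attn_energies  # (B, S)

def energies_one_loop(hidden, encoder_outputs):
    # Proposed form: loop only over time steps, scoring the whole batch at once
    max_len, batch_size, _ = encoder_outputs.shape
    attn_energies = torch.zeros(max_len, batch_size)
    hidden = hidden.transpose(0, 1).squeeze(1)  # (1, B, N) -> (B, N)
    for i in range(max_len):
        attn_energies[i] = batched_dot_score(hidden, encoder_outputs[i])
    return attn_energies.t()  # (S, B) -> (B, S) for comparison

S, B, N = 5, 3, 4
hidden = torch.randn(1, B, N)
encoder_outputs = torch.randn(S, B, N)
print(torch.allclose(energies_two_loops(hidden, encoder_outputs),
                     energies_one_loop(hidden, encoder_outputs), atol=1e-5))
```

The remaining loop runs S times instead of B × S times, and each iteration is a single batched elementwise multiply and sum.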
aevilorz commented 6 years ago

I also have two questions about this attention calculation code. BTW, I use PyTorch 0.4.

1. Why use Variable(torch.zeros(max_len, this_batch_size)) to store attn_energies?

When I first read this line, I doubted whether gradients would propagate correctly through it during the backward pass. After testing, I found that they do. However, .backward() could only be called once; the second call raised:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
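A small sketch of what the error message suggests, assuming a toy loss built from dot-product energies written into a preallocated zeros tensor (plain tensors instead of Variable, as in PyTorch 0.4+). It also checks that gradients flow correctly through the in-place writes. Passing retain_graph=True to the first .backward() keeps the saved buffers so a second call succeeds; .grad accumulates across the two calls.

```python
import torch

torch.manual_seed(0)
S, B, N = 4, 2, 3
hidden = torch.randn(B, N, requires_grad=True)
encoder_outputs = torch.randn(S, B, N, requires_grad=True)

# Write per-step energies into a preallocated buffer (in-place slice assignment)
attn_energies = torch.zeros(S, B)
for i in range(S):
    attn_energies[i] = (hidden * encoder_outputs[i]).sum(dim=-1)
loss = attn_energies.sum()

# Keep the graph's saved tensors so backward can run a second time
loss.backward(retain_graph=True)
first_grad = hidden.grad.clone()
# Without retain_graph=True above, this second call would raise the RuntimeError
loss.backward()
second_grad = hidden.grad.clone()

# d(loss)/d(hidden[b]) is the sum of encoder_outputs[i, b] over time steps
print(torch.allclose(first_grad, encoder_outputs.detach().sum(dim=0)))
# .grad accumulates, so after two calls it is twice the single-pass gradient
print(torch.allclose(second_grad, 2 * first_grad))
```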

So I think attn_energies could be stored this way instead:

```python
# '?' marks indices left unspecified in the original comment
attn_energies_lst = []

for i in range(?):
    attn_energy_i = self.score(hidden[?], encoder_outputs[?])
    attn_energies_lst.append(attn_energy_i)

attn_energies = torch.stack(attn_energies_lst)
```
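A runnable version of the list-and-stack pattern, assuming the ? placeholders stand for the loop over encoder time steps from the single-loop version earlier in the thread, hidden already reshaped to (B, N), and a dot-product score (batched_dot_score is a hypothetical helper). torch.stack builds attn_energies from graph-tracked results without preallocating a buffer.

```python
import torch

def batched_dot_score(hidden, encoder_output):
    # Hypothetical batched score: row-wise dot product of two (B, N) tensors -> (B,)
    return (hidden * encoder_output).sum(dim=-1)

S, B, N = 5, 3, 4
hidden = torch.randn(B, N, requires_grad=True)
encoder_outputs = torch.randn(S, B, N)

attn_energies_lst = []
for i in range(S):
    # One (B,) energy vector per encoder time step
    attn_energy_i = batched_dot_score(hidden, encoder_outputs[i])
    attn_energies_lst.append(attn_energy_i)

attn_energies = torch.stack(attn_energies_lst)  # S tensors of shape (B,) -> (S, B)
attn_energies.sum().backward()
print(attn_energies.shape)  # torch.Size([5, 3])
print(torch.allclose(hidden.grad, encoder_outputs.sum(dim=0)))
```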

2. Similar to what @pravn mentioned, can we calculate the attention weights in matrix form?

By combining .transpose(), .expand(), .matmul() and .cat(), attn_energies can be calculated without any loop.

First, reshape hidden to shape (B, N, 1) and encoder_outputs to (B, S, N). Then, for dot attention (and for general attention, after first applying the general score's N×N linear layer to encoder_outputs):

```python
attn_energies = torch.matmul(encoder_outputs, hidden)  # (B, S, N) @ (B, N, 1) -> (B, S, 1)
attn_energies = attn_energies.squeeze(-1)  # (B, S)
```
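A self-contained sketch of the loop-free dot and general scores, assuming the tutorial's layouts (hidden as (1, B, N), encoder_outputs as (S, B, N)); the nn.Linear named attn plays the role of the general score's weight matrix (name assumed). torch.bmm is used here; torch.matmul gives the same result for these 3-D inputs. The per-element loop is only there to check the vectorized result.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
S, B, N = 6, 2, 4
hidden = torch.randn(1, B, N)           # tutorial layout: (1, B, N)
encoder_outputs = torch.randn(S, B, N)  # tutorial layout: (S, B, N)
attn = nn.Linear(N, N)                  # assumed weight for "general" scoring

with torch.no_grad():
    h = hidden.permute(1, 2, 0)            # (1, B, N) -> (B, N, 1)
    enc = encoder_outputs.transpose(0, 1)  # (S, B, N) -> (B, S, N)

    # dot: e[b, s] = enc[b, s] . h[b]
    dot_energies = torch.bmm(enc, h).squeeze(-1)             # (B, S, 1) -> (B, S)
    # general: e[b, s] = attn(enc[b, s]) . h[b]
    general_energies = torch.bmm(attn(enc), h).squeeze(-1)   # (B, S)

    # Reference loop over batch and time steps, for checking only
    ref = torch.zeros(B, S)
    for b in range(B):
        for s in range(S):
            ref[b, s] = hidden[0, b].dot(attn(encoder_outputs[s, b]))

print(dot_energies.shape)  # torch.Size([2, 6])
print(torch.allclose(general_energies, ref, atol=1e-5))
```

A softmax over the last dimension of either result then gives per-example attention weights of shape (B, S).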

For concat attention:

```python
hidden_expand = hidden.expand(-1, -1, S).transpose(1, 2)  # (B, N, 1) -> (B, N, S) -> (B, S, N)
enc_cat_hid = torch.cat([encoder_outputs, hidden_expand], dim=-1)  # (B, S, 2*N)
# attn is the nn.Linear(2*N, N) layer of the concat score (name assumed)
energy = attn(enc_cat_hid)  # (B, S, N)
# v has shape (N,)
attn_energies = torch.matmul(energy, v)  # (B, S, N) @ (N,) -> (B, S)
```
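A runnable sketch of the loop-free concat score, following the shapes in the comment above. attn (an nn.Linear(2*N, N)) and v (a length-N vector) are assumed names; like the snippet, it applies no nonlinearity between them, and it keeps the snippet's concatenation order [encoder_outputs, hidden], so any trained weights must match that layout. The final loop only checks the vectorized result.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, N = 2, 5, 4
hidden = torch.randn(B, N, 1)           # reshaped as described: (B, N, 1)
encoder_outputs = torch.randn(B, S, N)  # reshaped as described: (B, S, N)
attn = nn.Linear(2 * N, N)              # assumed concat-score layer
v = torch.randn(N)                      # assumed scoring vector

with torch.no_grad():
    hidden_expand = hidden.expand(-1, -1, S).transpose(1, 2)           # (B, S, N)
    enc_cat_hid = torch.cat([encoder_outputs, hidden_expand], dim=-1)  # (B, S, 2*N)
    energy = attn(enc_cat_hid)                                         # (B, S, N)
    attn_energies = torch.matmul(energy, v)                            # (B, S)

    # Reference: score each (batch, time step) pair separately
    ref = torch.zeros(B, S)
    for b in range(B):
        for s in range(S):
            pair = torch.cat([encoder_outputs[b, s], hidden[b, :, 0]])  # (2*N,)
            ref[b, s] = v.dot(attn(pair))

print(attn_energies.shape)  # torch.Size([2, 5])
print(torch.allclose(attn_energies, ref, atol=1e-5))
```

expand() only creates a view, so the hidden state is not copied S times until torch.cat materializes the concatenated tensor.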