urchade opened 2 years ago
Hi, I'm also in need of an autoregressive model with end_class. Here's my approach: https://github.com/CarlossShi/pytorch-struct/commit/b5a56e8aee2742586b4e432fafd2c6b7be63273c. I use a variable active to record whether each sequence has ever output end_class; if no sequence is still alive, the for loop breaks early to save time. I am not very familiar with NLP, so I am not sure whether this is common practice.
In addition, some problems remain to be solved: the _beam_search method does not work as expected (see the example below). Following the official documentation on Autoregressive / Beam Search, I made some examples.
import torch
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, 'pytorch-struct')
import torch_struct

batch, layer, H, C, N, K = 3, 1, 5, 4, 10, 2  # K: sample shape
init = (torch.rand(batch, layer, H),
        torch.rand(batch, layer, H))

def t(a):
    return [t.transpose(0, 1) for t in a]

def show_ar(chain):
    plt.imshow(chain.detach().transpose(0, 1))

class RNN_AR(torch.nn.Module):
    def __init__(self, sparse=True):
        super().__init__()
        self.sparse = sparse
        self.rnn = torch.nn.RNN(H, H, batch_first=True)
        self.proj = torch.nn.Linear(H, C)
        if sparse:
            self.embed = torch.nn.Embedding(C, H)
        else:
            self.embed = torch.nn.Linear(C, H)

    def forward(self, inputs, state):
        """
        @param inputs: {Tensor: (batch, 1)}
        @param state: e.g. ({Tensor: (batch, layer, H)}, {Tensor: (batch, layer, H)})
        @return: {Tensor: (batch, layer, C)}, [{Tensor: (batch, layer, H)}]
        """
        if not self.sparse and inputs.dim() == 2:
            inputs = torch.nn.functional.one_hot(inputs, C).float()
        inputs = self.embed(inputs)  # {Tensor: (batch, 1, H)}
        out, state = self.rnn(inputs, t(state)[0])  # out: {Tensor: (batch, layer, H)}, t(state)[0] & state: {Tensor: (layer, batch, H)}
        out = self.proj(out)  # {Tensor: (batch, layer, C)}
        return out, t((state,))  # t((state,))[0]: {Tensor: (batch, layer, H)}

dist = torch_struct.Autoregressive(RNN_AR(), init, C, N, end_class=1)

path, scores, logits = dist.greedy_max()  # path, logits: {Tensor: (batch, N, C)}, scores: {Tensor: (batch,)}
for b in range(batch):
    plt.subplot(1, batch, b + 1)
    plt.axis('off')
    show_ar(path[b])
plt.suptitle('dist.greedy_max()')
plt.show()

out = dist.sample(torch.Size([K]))  # {Tensor: (K, batch, N, C)}
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.sample(torch.Size([K]))')
plt.show()

out = dist.beam_topk(K)  # {Tensor: (K, batch, N, C)}, first output of _beam_search
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.beam_topk(K)')
plt.show()
The output images are as follows.
In the example above, end_class is set to 1. I expect that once all sentences have reached the end_class (i.e. there is a yellow square in the second row of each array), the remaining columns are truncated. It seems that the sample method works as expected, but _beam_search does not. I'm not very familiar with the beam search function, so I am stuck here; a small check for the behaviour I expect is sketched below.
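This is a rough helper I used to check the expectation (ends_cleanly is my own name, not library API); it verifies that once end_class fires in a one-hot path, only end_class or all-zero padding follows:

def ends_cleanly(path, end_class=1):
    """Check one (N, C) one-hot path: after the first end_class,
    every step should be end_class again or an all-zero column."""
    classes = path.argmax(-1)                                    # (N,)
    hits = (classes == end_class).nonzero(as_tuple=True)[0]
    if len(hits) == 0:
        return True                                              # never ended; nothing to check
    first = int(hits[0])
    after = classes[first + 1:]
    padded = path[first + 1:].sum(-1) == 0                       # all-zero (truncated) columns
    return bool(((after == end_class) | padded).all())

# e.g. all(ends_cleanly(out[k, b]) for k in range(K) for b in range(batch))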
Hope that helps, and any further support would be greatly appreciated.
end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49
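If you need that behaviour anyway, one workaround is to post-process the decoded paths yourself. A rough sketch (truncate_after_end is my own helper, not part of pytorch-struct); it zeroes out the one-hot columns that come after the first end_class in each sequence:

import torch

def truncate_after_end(paths, end_class=1):
    """paths: (..., N, C) one-hot decodes; zero out every step
    strictly after the first occurrence of end_class."""
    classes = paths.argmax(-1)                                   # (..., N)
    ended = (classes == end_class).long().cumsum(-1) > 0         # True from first end_class onwards
    ended_before = torch.zeros_like(ended)
    ended_before[..., 1:] = ended[..., :-1]                      # shift by one: keep the end_class step itself
    keep = ~ended_before
    return paths * keep.unsqueeze(-1).to(paths.dtype)

# e.g. truncate_after_end(dist.beam_topk(K), end_class=1)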