harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

end_class support for Autoregressive #125

Open · urchade opened this issue 2 years ago

urchade commented 2 years ago

The end_class argument is not used anywhere in the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49
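For context, one common way to honor an end class at decoding time is to mask the next-step scores of sequences that have already finished, so they can only keep emitting end_class at no extra cost. Below is a minimal sketch in plain PyTorch, assuming scores of shape (batch, C). mask_finished is a hypothetical helper and is not part of torch_struct.

```python
import torch


def mask_finished(logits, finished, end_class):
    """Hypothetical helper, not a torch_struct API.

    logits:   (batch, C) next-step scores (logits or log-probabilities)
    finished: (batch,) bool, True once a sequence has emitted end_class
    """
    # Row template for finished sequences: only end_class is allowed, with score 0.
    # After log_softmax such a row stays [-inf, 0, -inf, ...], i.e. probability 1
    # for end_class, so the sequence's total log-probability no longer changes.
    forced = torch.full(logits.shape, float("-inf"),
                        dtype=logits.dtype, device=logits.device)
    forced[:, end_class] = 0.0
    # Unfinished rows keep their original scores.
    return torch.where(finished.unsqueeze(-1), forced, logits)


logits = torch.tensor([[0.5, 1.0, 2.0],
                       [3.0, 0.1, 0.2]])
finished = torch.tensor([False, True])
masked = mask_finished(logits, finished, end_class=1)
print(masked.argmax(dim=-1).tolist())  # [2, 1]
```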

CarlossShi commented 2 years ago

Hi, I also need an autoregressive model with end_class. Here's my approach: https://github.com/CarlossShi/pytorch-struct/commit/b5a56e8aee2742586b4e432fafd2c6b7be63273c. I use a variable, active, to record whether each sequence has already output end_class; once no sequence is still active, the for loop breaks early to save time. I am not very familiar with NLP, so I am not sure whether this is common practice. In addition, some problems remain to be solved:

Following the official documentation (Autoregressive / Beam Search), I wrote a few examples:

import torch
import matplotlib.pyplot as plt
import sys
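sys.path.insert(1, 'pytorch-struct')  # import torch_struct from a local clone in ./pytorch-struct rather than an installed copy
import torch_struct
batch, layer, H, C, N, K = 3, 1, 5, 4, 10, 2  # K: number of samples for sample() and beam size for beam_topk()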
init = (torch.rand(batch, layer, H),
        torch.rand(batch, layer, H))

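def t(a):
    # swap the first two dims of each tensor: (batch, layer, H) <-> (layer, batch, H)
    return [x.transpose(0, 1) for x in a]

def show_ar(chain):
    # chain: (N, C); plotted with classes as rows and time steps as columns
    plt.imshow(chain.detach().transpose(0, 1))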

class RNN_AR(torch.nn.Module):
    def __init__(self, sparse=True):
        super().__init__()
        self.sparse = sparse
        self.rnn = torch.nn.RNN(H, H, batch_first=True)
        self.proj = torch.nn.Linear(H, C)
        if sparse:
            self.embed = torch.nn.Embedding(C, H)
        else:
            self.embed = torch.nn.Linear(C, H)

    def forward(self, inputs, state):
        """
        One decoding step.

        @param inputs: {Tensor: (batch, 1)}
        @param state:  e.g. ({Tensor: (batch, layer, H)}, {Tensor: (batch, layer, H)}); nn.RNN only uses the first tensor
        @return: {Tensor: (batch, 1, C)}, [{Tensor: (batch, layer, H)}]
        """
        if not self.sparse and inputs.dim() == 2:
            inputs = torch.nn.functional.one_hot(inputs, C).float()
        inputs = self.embed(inputs)  # {Tensor: (batch, 1, H)}
        out, state = self.rnn(inputs, t(state)[0])  # out: {Tensor: (batch, 1, H)}; t(state)[0] and the new state: {Tensor: (layer, batch, H)}
        out = self.proj(out)  # {Tensor: (batch, 1, C)}
        return out, t((state,))  # t((state,))[0]: {Tensor: (batch, layer, H)}

dist = torch_struct.Autoregressive(RNN_AR(), init, C, N, end_class=1)

path, scores, logits = dist.greedy_max()  # path, logits: {Tensor: (batch, N, C)}, scores: {Tensor: (batch,)}
for b in range(batch):
    plt.subplot(1, batch, b + 1)
    plt.axis('off')
    show_ar(path[b])
plt.suptitle('dist.greedy_max()')
plt.show()

out = dist.sample(torch.Size([K]))  # {Tensor: (K, batch, N, C)}
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.sample(torch.Size([K]))')
plt.show()

out = dist.beam_topk(K)  # {Tensor: (K, batch, N, C)}, first output of _beam_search
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.beam_topk(K)')
plt.show()

The output images are as follows.

[Output images: dist.greedy_max(), dist.sample(torch.Size([K])), dist.beam_topk(K)]

In the example above, end_class is set to 1. I expect that once every sequence has reached end_class (i.e. there is a yellow square in the second row, class 1, of each array), the remaining columns are truncated. The sample method seems to work as expected, but _beam_search does not. I'm not very familiar with the beam search function, so I'm stuck here.
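For comparison, one common way to make beam search respect an end class is sketched below for a single sequence in plain PyTorch. A finished hypothesis may only be extended with end_class at zero cost, so its score is frozen while it keeps competing for a slot in the beam, and the loop stops as soon as every hypothesis in the beam has finished. beam_search, the step_fn(prev, t) interface, and the toy probability table are hypothetical; this is not how torch_struct's _beam_search is implemented.

```python
import torch


def beam_search(step_fn, n_classes, max_len, K, start_class=0, end_class=1):
    """Hypothetical sketch: returns (K, L) token ids and (K,) log-scores, best first."""
    scores = torch.full((K,), float("-inf"))
    scores[0] = 0.0  # only one live hypothesis at the start
    seqs = torch.empty((K, 0), dtype=torch.long)
    prev = torch.full((K,), start_class, dtype=torch.long)
    done = torch.zeros(K, dtype=torch.bool)
    for t in range(max_len):
        logp = step_fn(prev, t)  # (K, C) log-probabilities
        # Finished hypotheses may only append end_class, and it costs nothing.
        frozen = torch.full(logp.shape, float("-inf"))
        frozen[:, end_class] = 0.0
        logp = torch.where(done.unsqueeze(-1), frozen, logp)
        cand = (scores.unsqueeze(-1) + logp).reshape(-1)  # (K * C,)
        top_scores, top_idx = cand.topk(K)
        beam_idx = torch.div(top_idx, n_classes, rounding_mode="floor")
        tok = top_idx % n_classes
        seqs = torch.cat([seqs[beam_idx], tok.unsqueeze(-1)], dim=1)
        scores = top_scores
        done = done[beam_idx] | (tok == end_class)
        prev = tok
        if done.all():  # every hypothesis has ended: truncate here
            break
    return seqs, scores


# Toy usage: class 1 is end_class; per-step probabilities ignore the prefix.
probs = torch.tensor([[0.1, 0.3, 0.6],   # step 0
                      [0.1, 0.8, 0.1],   # step 1
                      [0.1, 0.1, 0.8],   # steps 2-3 are never reached
                      [0.1, 0.1, 0.8]])
table = probs.log()


def step_fn(prev, t):
    return table[t].repeat(prev.shape[0], 1)  # (K, C)


seqs, scores = beam_search(step_fn, n_classes=3, max_len=4, K=2)
print(seqs.tolist())  # [[2, 1], [1, 1]]
```

Here the hypothesis [1] finishes at step 0 and is carried forward as [1, 1] with its score unchanged, and the search stops after step 1 instead of running to max_len. No length normalization is applied, so shorter finished hypotheses are favored, since live ones keep adding negative log-probabilities; many decoders add a length penalty for that reason.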

Hope that helps, and any further support would be greatly appreciated.
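As a standalone illustration of the active-flag approach described in this comment, here is a minimal greedy decoder in plain PyTorch that records which sequences have emitted end_class and breaks out of the loop once none is still active. It is not the code from the linked commit; greedy_decode, the step_fn(prev, t) interface returning (batch, C) log-probabilities, and padding finished sequences with end_class are assumptions made for illustration.

```python
import torch


def greedy_decode(step_fn, batch, max_len, start_class=0, end_class=1):
    """Hypothetical sketch: returns (batch, L) token ids with L <= max_len, and (batch,) scores."""
    prev = torch.full((batch,), start_class, dtype=torch.long)
    active = torch.ones(batch, dtype=torch.bool)  # True until end_class is emitted
    scores = torch.zeros(batch)
    tokens = []
    for t in range(max_len):
        logp = step_fn(prev, t)  # (batch, C) log-probabilities
        best_logp, best = logp.max(dim=-1)
        # Finished sequences are padded with end_class and stop accumulating score.
        best = torch.where(active, best, torch.full_like(best, end_class))
        scores = scores + torch.where(active, best_logp, torch.zeros_like(best_logp))
        tokens.append(best)
        active = active & (best != end_class)
        prev = best
        if not active.any():  # no sequence alive: stop early to save time
            break
    return torch.stack(tokens, dim=1), scores


# Toy usage: fixed per-step logits of shape (max_len, batch, C) that ignore prev.
T, B, C = 6, 2, 4
table = torch.full((T, B, C), -5.0)
table[0, :, 2] = 0.0   # both sequences start with class 2
table[1, 0, 1] = 0.0   # sequence 0 emits end_class at step 1
table[1, 1, 3] = 0.0   # sequence 1 emits class 3 at step 1
table[2, :, 1] = 0.0   # sequence 1 emits end_class at step 2
table[3:, :, 2] = 0.0  # would keep going if the loop did not stop


def step_fn(prev, t):
    return torch.log_softmax(table[t], dim=-1)


tokens, scores = greedy_decode(step_fn, batch=B, max_len=T)
print(tokens.tolist())  # [[2, 1, 1], [2, 3, 1]]
```

The output has 3 columns instead of max_len = 6 because the loop breaks as soon as both sequences have emitted end_class.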