suragnair / seqGAN

A simplified PyTorch implementation of "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" (Yu, Lantao, et al.).

Should oracle sampling be based on previously generated samples? #15

Closed · alibabadoufu closed this 3 years ago

alibabadoufu commented 4 years ago

https://github.com/suragnair/seqGAN/blob/ae8ffcd54977bd9ee177994c751f86d34f5f7aa3/helpers.py#L80

Hi, thanks so much for this code; it is really simple and easy to understand. There is one particular point I am trying to understand: when training with the MLE loss, oracle sampling always starts from the 'START LETTER'.

def sample(self, num_samples, start_letter=0):
        """
        Samples the network and returns num_samples samples of length max_seq_len.

        Outputs: samples
            - samples: num_samples x max_seq_len (a sampled sequence in each row)
        """

        samples = torch.zeros(num_samples, self.max_seq_len).type(torch.LongTensor)

        h = self.init_hidden(num_samples)
        inp = autograd.Variable(torch.LongTensor([start_letter]*num_samples))

        if self.gpu:
            samples = samples.cuda()
            inp = inp.cuda()

        for i in range(self.max_seq_len):
            out, h = self.forward(inp, h)                # out: num_samples x vocab_size (log-probabilities)
            out = torch.multinomial(torch.exp(out), 1)   # num_samples x 1 (sampling from each row)
            samples[:, i] = out.view(-1).data

            inp = out.view(-1)                           # feed the sampled token back in as the next input

        return samples

For example, in the code above, should out, h = self.forward(inp, h) instead be out, h = self.forward(samples[:, i-1], h)?

suragnair commented 4 years ago

Hi, I am not sure I completely understand the question. The next input is set to the previous output in the line inp = out.view(-1), which should be equivalent to samples[:, i-1] at the following step.
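
For illustration, here is a minimal standalone sketch (not code from this repo; the random log-probabilities are a stand-in for self.forward) showing that after samples[:, i] = out.view(-1).data, the reassigned inp holds exactly samples[:, i], which is samples[:, (i+1)-1] from the perspective of the next iteration:

import torch

num_samples, vocab_size, max_seq_len = 4, 10, 5
torch.manual_seed(0)

samples = torch.zeros(num_samples, max_seq_len, dtype=torch.long)
inp = torch.zeros(num_samples, dtype=torch.long)  # all rows begin with start_letter = 0

for i in range(max_seq_len):
    # stand-in for out, h = self.forward(inp, h): random log-probabilities
    out = torch.log_softmax(torch.randn(num_samples, vocab_size), dim=1)
    out = torch.multinomial(torch.exp(out), 1)  # num_samples x 1 (sampling from each row)
    samples[:, i] = out.view(-1)
    inp = out.view(-1)
    # at the top of iteration i+1, inp is samples[:, i], i.e. samples[:, (i+1)-1]
    assert torch.equal(inp, samples[:, i])

The only step where the two differ conceptually is i = 0, where inp is the start letter rather than any entry of samples.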