williamSYSU / TextGAN-PyTorch

TextGAN is a PyTorch framework for Generative Adversarial Networks (GANs) based text generation models.
MIT License
889 stars 205 forks

sampling data from the generator #30

Open dishavarshney082 opened 4 years ago

dishavarshney082 commented 4 years ago

I could not understand the way you are sampling data from the generator. Basically, how are you creating batches from num_samples?

def sample(self, num_samples, batch_size, start_letter=cfg.start_letter):
    """
    Samples the network and returns num_samples samples of length max_seq_len.
    :return samples: num_samples * max_seq_length (a sampled sequence in each row)
    """
    num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1
    samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long()

    # Generate sentences with multinomial sampling strategy
    for b in range(num_batch):
        hidden = self.init_hidden(batch_size)
        inp = torch.LongTensor([start_letter] * batch_size)
        if self.gpu:
            inp = inp.cuda()

        for i in range(self.max_seq_len):
            out, hidden = self.forward(inp, hidden, need_hidden=True)  # out: batch_size * vocab_size
            next_token = torch.multinomial(torch.exp(out), 1)  # batch_size * 1 (sampling from each row)
            samples[b * batch_size:(b + 1) * batch_size, i] = next_token.view(-1)
            inp = next_token.view(-1)
    samples = samples[:num_samples]

    return samples

williamSYSU commented 4 years ago

sample() draws num_samples sentences from the generator G. Sampling all num_samples sentences in a single pass could exhaust CUDA memory, so they are generated in num_batch batches of batch_size sentences each. The batches are written into one tensor of num_batch * batch_size rows, which is then truncated to the first num_samples rows.
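
To make the batching concrete, here is a minimal sketch of the same arithmetic in plain Python. It is not the repository's code: batched_sample is a hypothetical name, and a uniform draw from Python's random module stands in for the generator and torch.multinomial. Only the batch count, the row indexing, and the final truncation are taken from the sample() method quoted above.

```python
import random


def batched_sample(num_samples, batch_size, max_seq_len, vocab_size, seed=0):
    """Toy stand-in for sample(): same batching logic, no neural network.

    Assumption: a uniform distribution over the vocabulary replaces the
    generator's output; only the batching and truncation mirror the thread.
    """
    rng = random.Random(seed)

    # Same batch count as the quoted code: floor division plus one extra
    # batch, unless num_samples equals batch_size exactly.
    num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1

    # Preallocate num_batch * batch_size rows, like torch.zeros(...).long().
    samples = [[0] * max_seq_len for _ in range(num_batch * batch_size)]

    for b in range(num_batch):
        for i in range(max_seq_len):
            # One token per row of the current batch
            # (hypothetical replacement for torch.multinomial).
            tokens = rng.choices(range(vocab_size), k=batch_size)
            for row, tok in enumerate(tokens):
                samples[b * batch_size + row][i] = tok

    # Drop the surplus rows so exactly num_samples sequences remain.
    return num_batch, samples[:num_samples]


if __name__ == "__main__":
    for n in (100, 64, 32, 10):
        num_batch, rows = batched_sample(n, batch_size=32, max_seq_len=5, vocab_size=10)
        print(n, num_batch, len(rows))
```

With num_samples=100 and batch_size=32 this runs 4 batches (128 rows) and keeps the first 100. With num_samples=64 it runs 3 batches although 2 would suffice, because the formula adds a batch whenever num_samples differs from batch_size. The surplus rows are discarded by the truncation, so the result is still correct, only slightly wasteful.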

dishavarshney082 commented 4 years ago

Does sampling from the generator mean the same thing as running inference with the pretrained MLE model? In RelGAN, are the models used for pretraining and for generating samples different?