suragnair / seqGAN

A simplified PyTorch implementation of "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient." (Yu, Lantao, et al.)

Oracle sampler when training the discriminator (training on own dataset) #22

Closed: TerenceChen95 closed this issue 2 years ago

TerenceChen95 commented 2 years ago

Hi suragnair!

Really appreciate your work! After looking through your project and its issues, I have a question: when training on my own dataset, how can I pass a real-world data sampler to `train_discriminator(discriminator, dis_opt, real_data_samples, generator, oracle, d_steps, epochs)`? Inside it, `pos_val = oracle.sample(100)` draws positive samples from the oracle.

Since I don't have an oracle, should I use my own data to pretrain such a sampler in advance? But then everything depends on how well that sampler captures the distribution of my own dataset, doesn't it?
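With a real dataset, the oracle's role here seems to be only a source of "real" positive sequences. A minimal sketch, assuming the encoded dataset is already a `LongTensor` of shape `(num_sequences, max_seq_len)`, is to hold out a slice of it to stand in for `oracle.sample(100)` instead of sampling from a model. `split_real_data` is a hypothetical helper, not part of the repository.

```python
import torch


def split_real_data(data: torch.Tensor, num_val: int, seed: int = 0):
    """Hypothetical helper: shuffle the encoded real sequences and hold out
    `num_val` rows to play the role of `oracle.sample(num_val)`.

    Assumes `data` is a LongTensor of shape (num_sequences, max_seq_len).
    """
    if not 0 < num_val < data.size(0):
        raise ValueError("num_val must be between 1 and len(data) - 1")
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(data.size(0), generator=gen)
    val_idx, train_idx = perm[:num_val], perm[num_val:]
    # Training positives (real_data_samples) and held-out positives (pos_val).
    return data[train_idx], data[val_idx]


# Toy "dataset": 1000 encoded sequences of length 20 over a 5000-token vocabulary.
real = torch.randint(0, 5000, (1000, 20))
real_data_samples, pos_val = split_real_data(real, num_val=100)
print(real_data_samples.shape, pos_val.shape)  # torch.Size([900, 20]) torch.Size([100, 20])
```

Keeping the held-out positives disjoint from `real_data_samples` means the discriminator is checked on real sequences it was not trained on.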

TerenceChen95 commented 2 years ago

Another solution that comes to mind is to randomly pick encoded sequences from the dataset.
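The random-pick idea could be sketched as a small wrapper exposing the same `sample(num_samples)` call the training code makes on `oracle`. This is an assumption about the interface: it presumes the code only calls `oracle.sample(n)` and expects a `LongTensor` of shape `(n, max_seq_len)` back (move the result to the GPU if the rest of the code runs on CUDA). `RealDataSampler` is a hypothetical name.

```python
import torch


class RealDataSampler:
    """Hypothetical stand-in for the oracle: draws real encoded sequences
    uniformly at random instead of generating them from a model."""

    def __init__(self, data: torch.Tensor):
        # `data` is assumed to be a LongTensor of shape (num_sequences, max_seq_len).
        self.data = data

    def sample(self, num_samples: int) -> torch.Tensor:
        n = self.data.size(0)
        if num_samples <= n:
            # Without replacement: no repeated rows within one call.
            idx = torch.randperm(n)[:num_samples]
        else:
            # More rows requested than exist: fall back to sampling with replacement.
            idx = torch.randint(0, n, (num_samples,))
        return self.data[idx]


real = torch.randint(0, 5000, (1000, 20))
oracle = RealDataSampler(real)
pos_val = oracle.sample(100)
print(pos_val.shape)  # torch.Size([100, 20])
```

Sampling the data directly avoids pretraining a separate sampler: the empirical data is exactly what the discriminator should learn to label as real. If the training script also uses the oracle to score the generator (for example, an oracle NLL metric), that part has no equivalent with a data sampler.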