williamSYSU / TextGAN-PyTorch

TextGAN is a PyTorch framework for Generative Adversarial Networks (GANs) based text generation models.
MIT License

Doesn't work with torch 1.3.1 #18

Closed egglang closed 4 years ago

egglang commented 4 years ago

Thank you for developing this great software.

At present, the newest stable version of torch is 1.3.1. With 1.3.1, an error occurs in dataloader.py when trying to run some scripts in the run directory. Everything works fine with torch 1.2.

$ python run_seqgan.py 0 0
job_id: 0, gpu_id: 0
====================================================================================================
> training arguments:
>>> if_test: 0
>>> run_model: seqgan
>>> k_label: 2
>>> dataset: oracle
>>> model_type: vanilla
>>> loss_type: rsgan
>>> if_real_data: 0
>>> cuda: 1
>>> device: 0
>>> shuffle: 0
>>> gen_init: normal
>>> dis_init: uniform
>>> samples_num: 10000
>>> vocab_size: 5000
>>> mle_epoch: 120
>>> clas_pre_epoch: 10
>>> adv_epoch: 200
>>> inter_epoch: 15
>>> batch_size: 64
>>> max_seq_len: 20
>>> start_letter: 1
>>> padding_idx: 0
>>> gen_lr: 0.01
>>> gen_adv_lr: 0.0001
>>> dis_lr: 0.0001
>>> clip_norm: 5.0
>>> pre_log_step: 10
>>> adv_log_step: 1
>>> train_data: dataset/oracle.txt
>>> test_data: dataset/testdata/oracle_test.txt
>>> temp_adpt: exp
>>> temperature: 1
>>> ora_pretrain: 1
>>> gen_pretrain: 0
>>> dis_pretrain: 0
>>> adv_g_step: 1
>>> rollout_num: 16
>>> gen_embed_dim: 32
>>> gen_hidden_dim: 32
>>> goal_size: 16
>>> step_size: 4
>>> mem_slots: 1
>>> num_heads: 2
>>> head_size: 256
>>> d_step: 5
>>> d_epoch: 3
>>> adv_d_step: 4
>>> adv_d_epoch: 2
>>> dis_embed_dim: 64
>>> dis_hidden_dim: 64
>>> num_rep: 64
>>> log_file: log/log_1204_0925_58.txt
>>> save_root: save/20191204/oracle/seqgan_vanilla_lt-rsgan_sl20_temp1_T1204_0925_58/
>>> signal_file: run_signal.txt
>>> tips: SeqGAN experiments
====================================================================================================
Creating Oracle...
NLL_Oracle Groud Truth: 5.9067
Starting Generator MLE Training...
Traceback (most recent call last):
  File "main.py", line 127, in <module>
    inst._run()
  File "/content/textgan/instructor/oracle_data/seqgan_instructor.py", line 52, in _run
    self.pretrain_generator(cfg.MLE_train_epoch)
  File "/content/textgan/instructor/oracle_data/seqgan_instructor.py", line 101, in pretrain_generator
    '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True)))
  File "/content/textgan/instructor/oracle_data/instructor.py", line 180, in cal_metrics
    self.gen_data.reset(self.gen.sample(cfg.samples_num, 4 * cfg.batch_size))
  File "/content/textgan/utils/data_loader.py", line 64, in reset
    self.loader.dataset = data
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 270, in __setattr__
    'initialized'.format(attr, self.__class__.__name__))
ValueError: dataset attribute should not be set after DataLoader is initialized
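For reference, since PyTorch 1.2 the DataLoader forbids reassigning `dataset` after construction, so a `reset()` that does `self.loader.dataset = data` has to rebuild the loader instead. A minimal sketch of that pattern (class and method names here are illustrative, not the repo's actual API):

```python
from torch.utils.data import DataLoader, Dataset


class ListDataset(Dataset):
    """Illustrative dataset wrapping a plain list of samples."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class GenDataIter:
    """Sketch of a data iterator whose reset() recreates the DataLoader
    instead of mutating loader.dataset (disallowed in torch >= 1.2)."""

    def __init__(self, data, batch_size=64):
        self.batch_size = batch_size
        self.loader = self._new_loader(data)

    def _new_loader(self, data):
        return DataLoader(ListDataset(data),
                          batch_size=self.batch_size,
                          shuffle=False)

    def reset(self, data):
        # Build a fresh DataLoader rather than assigning to
        # self.loader.dataset, which raises ValueError in newer torch.
        self.loader = self._new_loader(data)
```

Rebuilding the loader is cheap because DataLoader construction is lazy; workers are only spawned when iteration starts.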

Thanks,

williamSYSU commented 4 years ago

Thanks for your feedback! The code now works with PyTorch 1.3.1.

egglang commented 4 years ago

I really appreciate it! It works perfectly!