williamSYSU / TextGAN-PyTorch

TextGAN is a PyTorch framework for Generative Adversarial Network (GAN)-based text generation models.
MIT License

GPU (11 GB) OOM when running run_seqgan.py: how can I avoid it by lowering some parameters? #7

Closed: SeekPoint closed this issue 5 years ago

SeekPoint commented 5 years ago

```
ub16c9@ub16c9-gpu:/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/run$ python3.6 run_seqgan.py 0 0
job_id: 0, gpu_id: 0
```

training arguments:

```
if_test: 0  run_model: seqgan  dataset: oracle  model_type: vanilla  loss_type: JS
if_real_data: 0  cuda: 1  device: 0  shuffle: 0  use_truncated_normal: 0
samples_num: 10000  vocab_size: 5000  mle_epoch: 120  adv_epoch: 200  inter_epoch: 10
batch_size: 64  max_seq_len: 20  start_letter: 1  padding_idx: 0
gen_lr: 0.01  gen_adv_lr: 0.0001  dis_lr: 0.0001  clip_norm: 5.0
pre_log_step: 5  adv_log_step: 1
train_data: dataset/oracle.txt  test_data: dataset/testdata/oracle_test.txt
temp_adpt: exp  temperature: 1  ora_pretrain: 1  gen_pretrain: 0  dis_pretrain: 0
adv_g_step: 1  rollout_num: 32  gen_embed_dim: 32  gen_hidden_dim: 32
goal_size: 16  step_size: 4  mem_slots: 1  num_heads: 2  head_size: 256
d_step: 50  d_epoch: 3  adv_d_step: 5  adv_d_epoch: 3
dis_embed_dim: 64  dis_hidden_dim: 64  num_rep: 64
log_file: log/log_0611_1726.txt
save_root: save/relgan_vanilla_oracle_RSGAN_glr0.01_temp2_T0611-1726/
signal_file: run_signal.txt
tips: vanilla SeqGAN
```

```
Starting Generator MLE Training...
[MLE-GEN] epoch 0 : pre_loss = 7.8519, oracle_NLL = 10.1270, gen_NLL = 7.6967,
[MLE-GEN] epoch 5 : pre_loss = 7.1450, oracle_NLL = 9.4151, gen_NLL = 6.9945,
[MLE-GEN] epoch 10 : pre_loss = 6.7705, oracle_NLL = 9.2543, gen_NLL = 6.7171,
.....
[MLE-DIS] d_step 47: d_loss = 0.0005, train_acc = 0.9999, eval_acc = 0.5068,
[MLE-DIS] d_step 48: d_loss = 0.0018, train_acc = 0.9993, eval_acc = 0.5020,
[MLE-DIS] d_step 49: d_loss = 0.0010, train_acc = 0.9997, eval_acc = 0.5010,
Save pretrain_generator discriminator: pretrain/oracle_data/dis_pretrain_seqgan_vanilla.pt
Starting Adversarial Training...
Initial generator: oracle_NLL = 9.0884, gen_NLL = 5.9763,

ADV EPOCH 0

Traceback (most recent call last):
  File "main.py", line 100, in <module>
    inst._run()
  File "/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/instructor/oracle_data/seqgan_instructor.py", line 72, in _run
    self.adv_train_generator(cfg.ADV_g_step)  # Generator
  File "/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/instructor/oracle_data/seqgan_instructor.py", line 120, in adv_train_generator
    rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis)
  File "/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/utils/rollout.py", line 142, in get_reward
    samples = self.rollout_mc_search(sentences, given_num)
  File "/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/utils/rollout.py", line 36, in rollout_mc_search
    out, hidden = self.gen.forward(inp, hidden, need_hidden=True)
  File "/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/models/generator.py", line 44, in forward
    out = self.temperature * out  # temperature
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.92 GiB total capacity; 9.41 GiB already allocated; 31.19 MiB free; 81.37 MiB cached)
ub16c9@ub16c9-gpu:/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/TextGAN-PyTorch/run$
```

williamSYSU commented 5 years ago

The CUDA memory is mainly occupied by the Rollout module during the Monte Carlo search. Thus, there are two ways to reduce CUDA memory usage:
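As an illustration of the usual levers (a hedged sketch, not necessarily the two ways meant above): one can lower the `rollout_num`/`batch_size` values visible in the training arguments, and one can run the Monte Carlo search under `torch.no_grad()`. The `mc_rollout` helper below and its signature are hypothetical; the real loop lives in `utils/rollout.py` and may differ. Only `gen.forward(inp, hidden, need_hidden=True)` is taken from the traceback above.

```python
import torch

# Option 1 (no code change): lower the config values shown in the log, e.g.
#   rollout_num: 32 -> 16   and/or   batch_size: 64 -> 32
# Fewer rollouts per step give noisier reward estimates but use far less memory.

# Option 2: run the Monte Carlo search without autograd bookkeeping.
# Hypothetical stand-in for the rollout loop; not the repo's actual API.
def mc_rollout(gen, inp, hidden, remaining_steps):
    samples = []
    # The rollout never calls backward(), so no_grad() lets PyTorch free
    # intermediate activations instead of retaining them for backprop.
    with torch.no_grad():
        for _ in range(remaining_steps):
            out, hidden = gen.forward(inp, hidden, need_hidden=True)
            probs = torch.softmax(out, dim=-1)             # logits -> probabilities (sketch)
            inp = torch.multinomial(probs, 1).squeeze(-1)  # sample next token
            samples.append(inp)
    return torch.stack(samples, dim=1)  # shape: (batch, remaining_steps)
```

In SeqGAN-style training the rollout samples only feed the reward estimate, which is treated as a constant in the policy-gradient update, so detaching them this way should not change the generator's gradients, only the memory held during the search.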