Closed w82318029 closed 4 years ago
Have a look at these flags - if you set --lm_weight 0 --qa_weight 0 --disc_weight 1.0 --disc_train
that should do what you're looking for.
Thank you for your quick response.
One more question. I'm trying to run your code on my own data, so I need to train the discriminator from scratch. How do I get word_embedding, char_embedding, word_dictionary and char_dictionary? Thank you!
To be honest, I can't remember - but the discriminator was based on this QANet implementation. I think you'd need to follow the instructions on there. Or just use the embeddings/vocab I provide.
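For reference, here is a generic sketch of how a word dictionary and embedding matrix are typically built from GloVe-format vectors. This is not the repo's (or QANet's) actual preprocessing — the function name, the PAD/OOV convention, and the toy input are all my own assumptions — but the shapes it produces are what most QANet-style loaders expect:

```python
import numpy as np

def build_vocab_and_embeddings(glove_lines, dim):
    """Build a word dictionary and embedding matrix from GloVe-format lines.

    Each line is "<word> <v1> <v2> ...". Index 0 is reserved for PAD and
    index 1 for OOV (both zero-initialised) - a common convention, though
    the repo's own preprocessing may differ.
    """
    word_dictionary = {"<PAD>": 0, "<OOV>": 1}
    vectors = [np.zeros(dim), np.zeros(dim)]
    for line in glove_lines:
        parts = line.rstrip().split(" ")
        word, vec = parts[0], np.asarray(parts[1:], dtype=np.float32)
        if len(vec) != dim:
            continue  # skip malformed lines
        word_dictionary[word] = len(word_dictionary)
        vectors.append(vec)
    word_embedding = np.stack(vectors)  # shape: (vocab_size, dim)
    return word_dictionary, word_embedding

# Toy usage with an in-memory "GloVe file"; char_dictionary/char_embedding
# can be built the same way over individual characters.
lines = ["the 0.1 0.2", "cat 0.3 0.4"]
vocab, emb = build_vocab_and_embeddings(lines, dim=2)
```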
thank you!
Sorry to bother you again. I set --lm_weight 0 --qa_weight 0 --disc_weight 1.0 --disc_train and got an error after 215 steps. I then tried --lm_weight 0.25 --qa_weight 0.5 --disc_weight 1.0 --disc_train and got the same error, also at step 215. Have you encountered this error?
```
Model type is RL-S2S
Loaded SQUAD with 75722 triples
Modifying seq2seq model to incorporate rl rewards
Building and loading lm
Building and loading qa model
Total number of trainable parameters: 788673
Total number of trainable parameters: 33135137
Total number of trainable parameters: 806785
loading wangyanmeng ./models/saved/discriminator
Loading discriminator from ./models/saved/discriminator/model.checkpoint-4000
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1334, in _do_call
    return fn(*args)
  File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1319, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1407, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
```
It looks like the error occurs while loading the discriminator, but it actually only appears at the 215th step.
Does it work with the following flags set?
--lm_weight 0.25 --qa_weight 0.5 --disc_weight 1.0 --nodisc_train
--lm_weight 0.25 --qa_weight 0.5 --disc_weight 0.0 --nodisc_train
The error mentions empty tensors, which sounds like there might be an issue with your dataset?
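For what it's worth, that TensorFlow message appears when a batch ends up with zero elements and the target shape combines a -1 with a zero-sized dimension, so the -1 can't be inferred. NumPy shows the same failure mode — this is just an illustration of the error, not code from this repo:

```python
import numpy as np

# An empty batch: 0 examples, each of width 3.
empty_batch = np.zeros((0, 3))

# Reshaping with -1 is fine as long as the other dims are non-zero:
ok = empty_batch.reshape(-1, 3)  # shape (0, 3)

# But with a zero dim alongside -1, the missing size can't be inferred.
# NumPy raises ValueError here; TensorFlow raises the InvalidArgumentError
# quoted above ("Reshape cannot infer the missing input size ...").
try:
    empty_batch.reshape(0, -1)
    raised = False
except ValueError:
    raised = True
```

So if some filtering step produces an empty batch at step 215, a reshape like this downstream would fail with exactly that message.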
Thank you for your suggestions. On the SQuAD 1.1 dataset it didn't work yesterday. Then I tried --lm_weight 0.25 --qa_weight 0.5 --disc_weight 1.0 --nodisc_train and it worked fine, and --lm_weight 0.25 --qa_weight 0.5 --disc_weight 1.0 --disc_train also worked fine. Maybe "--nodisc_train" changed something? Saved a disc model?
I'm not sure why that would help - but glad it's working now! The discriminator code was fairly late in the project so it's definitely not as robust as the rest.
Thank you for your advice. Could you tell me how to fine-tune a pretrained seq2seq network with only the adversarial discriminator? Set the flag restore=True? Anything else? Thanks a lot!
Have a look at these flags - if you set
--lm_weight 0 --qa_weight 0 --disc_weight 1.0 --disc_train
that should do what you're looking for.
This set of flags will disable the LM/QA rewards, just use the discriminator, and allow the discriminator to be trainable (so that it becomes adversarial).
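Judging from the flag names, the overall RL reward is presumably a weighted combination of the three signals. A hypothetical sketch of that weighting (my own function name and example values, not the repo's actual code), just to make the flag semantics concrete:

```python
def combined_reward(lm_r, qa_r, disc_r,
                    lm_weight=0.0, qa_weight=0.0, disc_weight=1.0):
    """Hypothetical weighted reward matching the flag semantics above:
    with lm_weight=0 and qa_weight=0, only the discriminator reward
    contributes to the seq2seq training signal."""
    return lm_weight * lm_r + qa_weight * qa_r + disc_weight * disc_r

# Discriminator-only setting (--lm_weight 0 --qa_weight 0 --disc_weight 1.0):
r = combined_reward(lm_r=0.5, qa_r=0.8, disc_r=0.3)  # only disc_r survives
```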
Sorry to bother you again. I got this error again: "Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero". It seems to be related to low GPU memory. I use a V100 GPU with 16 GB of memory, and the code takes almost all of it. How much memory does it take when you run this code? Any advice? Thanks a lot.
Ah yes, good point - you may need to reduce the batch size with --batch_size 8 or similar to get everything to fit in memory; the default batch size assumes you aren't loading the extra models.
Thank you for the suggestion ! It works fine now.
Thanks for sharing your code! I'm trying to run it with only the adversarial discriminator. Could you tell me how to disable the QA, LM and discriminator rewards? Thank you!