CR-Gjx / LeakGAN

The code for the paper "Long Text Generation via Adversarial Training with Leaked Information" (AAAI 2018): text generation using a GAN and hierarchical reinforcement learning.
https://arxiv.org/abs/1709.08624

Improve LeakGAN by Changing Policy Gradient Structure #2

Open NickShahML opened 7 years ago

NickShahML commented 7 years ago

Hey @CR-Gjx, thanks for providing this open-source code. It's very helpful to study, and I love the idea of hierarchical reinforcement learning.

The recent AlphaGo Zero and Thinking Fast and Slow papers both show that replacing classic policy gradient with MCTS-guided gradients gives far better results and is more stable.

The RL problems these papers address are self-play, but I believe their techniques could be applied to LeakGAN and would improve its performance substantially.

Currently, LeakGAN takes a sequence and calculates a scalar reward. This scalar reward is then used in REINFORCE to improve the worker. We would leave the manager's objective the same.
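For reference, the scalar-reward REINFORCE update described above can be sketched in a few lines of pure Python. This is a minimal illustration assuming one reward R for the whole sampled sequence and an optional baseline b (both hypothetical inputs); it is not taken from the LeakGAN code.

```python
import math
from typing import List

def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one step's token logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def reinforce_loss(step_logits: List[List[float]], actions: List[int],
                   reward: float, baseline: float = 0.0) -> float:
    """Surrogate loss -(R - b) * sum_t log pi(a_t | s_t).

    Its gradient is the REINFORCE estimator: one scalar reward R for the
    whole sampled sequence scales the log-probability of every token
    that was actually chosen, and nothing else."""
    advantage = reward - baseline
    total_log_prob = sum(log_softmax(logits)[a]
                         for logits, a in zip(step_logits, actions))
    return -advantage * total_log_prob
```

Every step receives the same scalar credit, and only the sampled token at each step gets a direct learning signal; this is the property the comment contrasts with a distribution target.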

Instead of having just one correct target, the papers suggest using multiple correct targets (a distribution over actions), with a reward for each target within the distribution. Thus the cross-entropy is taken over n targets, each weighted by its respective reward.
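To make the "n targets" idea concrete, here is a small sketch of cross-entropy against a target distribution rather than a single token. The target vector (in AlphaGo Zero, the normalized MCTS visit counts) is simply an input here; how it would be produced for LeakGAN is an open assumption.

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits: List[float], target: List[float]) -> float:
    """H(target, p) = -sum_a target[a] * log p[a], with p = softmax(logits).
    A one-hot target recovers the usual single-target loss."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(t * (x - log_z) for t, x in zip(target, logits))

def cross_entropy_grad(logits: List[float], target: List[float]) -> List[float]:
    """dH/dlogits = softmax(logits) - target (assumes target sums to 1)."""
    return [p - t for p, t in zip(softmax(logits), target)]
```

With uniform logits over four tokens, the one-hot target [0, 0, 1, 0] gives the gradient [0.25, 0.25, -0.75, 0.25], while the target [0.1, 0.2, 0.6, 0.1] gives [0.15, 0.05, -0.35, 0.15]: every action gets its own graded signal instead of only the single chosen one.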

To generate individual rewards for each target, MCTS is used to improve upon the original policy. During the search they select actions with a = argmax_a [Q(s,a) + U(s,a)] to build the search tree. They then use the visit counts to calculate the targets (rather than a value function, because the value function leads to overfitting).
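The selection rule and the visit-count target can be sketched as follows. The PUCT form of U(s,a) and the temperature-scaled visit-count policy follow the AlphaGo Zero paper; c_puct = 1.0 and storing per-node statistics as flat Python lists are illustrative assumptions, not tuned or repo-specific choices.

```python
import math
from typing import List

def puct_select(prior: List[float], visits: List[int], value_sum: List[float],
                c_puct: float = 1.0) -> int:
    """Return argmax_a Q(s,a) + U(s,a) for one tree node.

    Q(s,a) = W(s,a) / N(s,a)  (0 for unvisited actions)
    U(s,a) = c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a))"""
    sqrt_total = math.sqrt(sum(visits))
    best_action, best_score = 0, -math.inf
    for a, (p, n, w) in enumerate(zip(prior, visits, value_sum)):
        q = w / n if n > 0 else 0.0
        u = c_puct * p * sqrt_total / (1 + n)
        if q + u > best_score:
            best_action, best_score = a, q + u
    return best_action

def visit_policy(visits: List[int], temperature: float = 1.0) -> List[float]:
    """Training target pi(a|s) proportional to N(s,a) ** (1 / temperature)."""
    powered = [n ** (1.0 / temperature) for n in visits]
    z = sum(powered)
    return [x / z for x in powered]
```

For example, visit counts [1, 3, 0, 4] at temperature 1 give the target [0.125, 0.375, 0.0, 0.5]. Note that each node needs P, N and W for every action, which is where the action count starts to matter.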

The fundamental difference here is that we are optimizing a distribution rather than a single target. This distribution naturally carries far more information for the Generator to benefit from. I think this would help immensely with mode collapse, which is currently only partially remedied by occasionally training with MLE. Thoughts on this?

CR-Gjx commented 7 years ago

Thanks for your suggestion! We also noticed these ideas and tried MCTS in our framework, but in text generation the number of actions is always more than 5000, versus 361 in Go, so the algorithm is limited by GPU memory. I think it could be a great idea if you have enough resources.
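A rough estimate helps show why the action count matters for GPU memory. This is a back-of-the-envelope sketch under loudly hypothetical assumptions: every expanded node stores a prior, a visit count and a value sum for every action (12 bytes per action), each simulation expands one node, and the search uses the 1600 simulations per move reported for AlphaGo Zero.

```python
# Hypothetical storage model: float32 prior P + int32 visit count N +
# float32 value sum W per action, i.e. 12 bytes per action per node.
BYTES_PER_ACTION = 12

def tree_bytes(num_actions: int, simulations: int,
               bytes_per_action: int = BYTES_PER_ACTION) -> int:
    """Rough size in bytes of one search tree, assuming every simulation
    expands exactly one node holding statistics for all actions."""
    return num_actions * bytes_per_action * simulations

if __name__ == "__main__":
    sims = 1600  # simulations per move used by AlphaGo Zero
    for name, actions in [("Go", 361), ("5k vocab", 5000), ("80k vocab", 80000)]:
        print(f"{name}: {tree_bytes(actions, sims) / 1e6:.1f} MB per tree")
```

Under these assumptions one tree is about 6.9 MB for Go's 361 actions, 96 MB for a 5000-word vocabulary (about 6.1 GB if each of 64 sequences in a batch runs its own tree), and about 1.5 GB for an 80k vocabulary. Sparse storage of unvisited actions would shrink this, so treat the figures as upper-bound illustrations.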

NickShahML commented 7 years ago

Yes, I should have addressed that issue. I tested your repo with a larger vocab size (80k) and ran out of memory quickly. However, I think there are several ways to address this memory issue:

  1. The biggest memory problem with your approach is that LSTMs take an excessive amount of memory. With the Transformer network, on the other hand, you not only get improved results over a vanilla LSTM, but it also takes significantly less memory. On a single 1080 Ti, I can train a Transformer with a batch size of 2048, whereas for a comparable LSTM I can train a batch size of at most 64.

Another huge benefit of this network is that it can be trained in linear time, which means you can reduce the batch size even further (this would affect the ranking part of your algorithm, though).

In your paper, you use a small LSTM network with 128 units. If you used a comparable Transformer network (just the decoder portion), the memory footprint would be very small.

  2. Yes, there are 5000+ actions in text generation, but one way to reduce this problem is to use subword units instead of whole words. You can get reasonable text generation with a 4k vocab.

  3. Finally, I have a system with four 1080 Ti GPUs and would be happy to run any experiments you have. Additionally, AWS just released Volta GPUs for rent.
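The subword suggestion above is commonly implemented with byte-pair encoding (BPE), one subword scheme among several. The sketch below learns BPE merges from a toy word-count table; the function name and the data are hypothetical, and a real setup would more likely use an existing tokenizer library than this loop.

```python
from collections import Counter
from typing import Dict, List, Tuple

def learn_bpe(word_counts: Dict[str, int], num_merges: int) -> List[Tuple[str, str]]:
    """Learn BPE merges by repeatedly fusing the most frequent adjacent pair.

    The final vocabulary is the base characters plus one symbol per merge,
    so its size is controlled directly by num_merges (e.g. roughly 4k)."""
    # Each word starts as a tuple of single characters.
    vocab: Dict[Tuple[str, ...], int] = {tuple(w): c for w, c in word_counts.items()}
    merges: List[Tuple[str, str]] = []
    for _ in range(num_merges):
        pairs: Counter = Counter()
        for symbols, count in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += count
        if not pairs:
            break
        best = max(pairs, key=pairs.__getitem__)  # ties: first pair seen
        merges.append(best)
        merged_vocab: Dict[Tuple[str, ...], int] = {}
        for symbols, count in vocab.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            key = tuple(out)
            merged_vocab[key] = merged_vocab.get(key, 0) + count
        vocab = merged_vocab
    return merges
```

For example, `learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, 2)` returns `[("e", "s"), ("es", "t")]`: the frequent suffix "est" becomes a single symbol shared by "newest" and "widest".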

AranKomat commented 6 years ago

I'm working on implementing this AlphaZero + GAN + Transformer idea. A good thing about our case compared with board games is that the forward FLOPs required at each move are ~100-fold smaller, since the input of each layer in our case is bs x hidden_dim, whereas in Go it is bs x hidden_dim x 19 x 19. Furthermore, we can use fewer layers (e.g. 6) and drastically decrease the number of simulations per move, the latter of which I have some justification for. Thus, I believe you can do reasonable training with one or several GPUs without decreasing the hidden dimension from 256.

For simplicity, I've omitted the leakage of information and the hierarchical components in order to compare with SeqGAN. I'm not confident in the discriminability of unfinished sentences, so I'll try two cases: (1) assign the D score of the (finished) sentence to each leaf (no non-leaf node gets a score); (2) assign the D score of any sentence, finished or not, to any node, with z of a node being the mean over its child nodes.

Without a proper cache, Transformer inference is much slower than LSTM inference, whereas with a cache it can perform fast decoding, similar to Fast WaveNet, which makes it slightly faster than an LSTM. In my case, both G and D are Transformers.
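The cache mentioned above is the usual key/value cache for autoregressive decoding. A minimal single-head sketch (no batching, no multi-head or multi-layer logic, and not T2T's implementation) shows that caching keys and values gives exactly the same output as recomputing the whole prefix, while projecting only the newest token at each step.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(m: Mat, x: Vec) -> Vec:
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]

def attend(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Scaled dot-product attention of one query over all given positions."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z
            for j in range(len(values[0]))]

class CachedSelfAttention:
    """Causal single-head self-attention for token-by-token decoding."""

    def __init__(self, wq: Mat, wk: Mat, wv: Mat) -> None:
        self.wq, self.wk, self.wv = wq, wk, wv
        self.keys: List[Vec] = []
        self.values: List[Vec] = []

    def step(self, x: Vec) -> Vec:
        # Project only the newest token; earlier keys/values come from the cache.
        self.keys.append(matvec(self.wk, x))
        self.values.append(matvec(self.wv, x))
        return attend(matvec(self.wq, x), self.keys, self.values)

def uncached_step(wq: Mat, wk: Mat, wv: Mat, prefix: List[Vec]) -> Vec:
    """Same output without a cache: re-projects every prefix token each step."""
    keys = [matvec(wk, x) for x in prefix]
    values = [matvec(wv, x) for x in prefix]
    return attend(matvec(wq, prefix[-1]), keys, values)
```

In a full decoder the cache is kept per layer; without it, every decoding step re-runs all layers over the entire prefix, which is where the slowdown relative to an LSTM comes from.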

@NickShahML I don't see why the Transformer would be 32x more memory-efficient than the LSTM, since most of the memory is consumed by the embedding and softmax layers, which are identical in both architectures. How did you make the comparison? The batch size used in T2T's Transformer implementation corresponds to the total number of tokens rather than the total number of sentences. Is your Transformer batch size of 2048 really the same thing as a batch of 2048 sentences?
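A quick calculation supports the point about the softmax layer. Assuming float32 logits and 20-token sentences (both hypothetical choices) with the 80k vocabulary mentioned earlier in the thread, the output logits alone take:

```python
# Memory for the output logits alone; this layer is the same size whether
# the model body is an LSTM or a Transformer.
# Hypothetical assumptions: float32 logits, 20-token sentences.
def logits_bytes(num_sequences: int, seq_len: int, vocab_size: int,
                 bytes_per_float: int = 4) -> int:
    """Bytes for a num_sequences x seq_len x vocab_size float tensor."""
    return num_sequences * seq_len * vocab_size * bytes_per_float

if __name__ == "__main__":
    vocab = 80_000  # vocab size from the thread
    print("64 sentences:  ", logits_bytes(64, 20, vocab) / 1e9, "GB")
    print("2048 sentences:", logits_bytes(2048, 20, vocab) / 1e9, "GB")
    print("2048 tokens:   ", logits_bytes(1, 2048, vocab) / 1e9, "GB")
```

That is about 0.41 GB for 64 sentences, about 13.1 GB for 2048 sentences (more than the 11 GB on a single 1080 Ti, before gradients or any other activations), and about 0.66 GB for a 2048-token batch. The numbers are consistent with a token-based batch size, though they do not settle the question on their own.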

AranKomat commented 6 years ago

So, I've completed the aforementioned implementation and hyperparameter tuning, and I'm now trying to reach full convergence on ImageCOCO. I've detected significant mode collapse in LeakGAN on ImageCOCO. For example, according to generated_coco_examples.txt, the word "skateboard" appears 3261 times across nearly 10k sentences, but it doesn't appear very often in the actual dataset. The same can be said about other words such as "A" and "man". This can be attributed to the small generator and to REINFORCE. AlphaZero allows a larger architecture for the generator, so hopefully this issue will be mitigated.
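One simple way to quantify the skew described above is to compare each word's share of the generated tokens with its share of the reference data. The sketch below assumes whitespace tokenization and add-one smoothing, and takes plain lists of sentences as input; it does not read the files mentioned in the thread.

```python
from collections import Counter
from typing import Iterable, List, Tuple

def word_counts(sentences: Iterable[str]) -> Counter:
    counts: Counter = Counter()
    for s in sentences:
        counts.update(s.split())  # whitespace tokenization, case-sensitive
    return counts

def overrepresented(generated: Iterable[str], reference: Iterable[str],
                    top_k: int = 5) -> List[Tuple[str, float]]:
    """Rank generated words by (share in generated) / (share in reference).
    Add-one smoothing keeps words absent from the reference finite."""
    gen, ref = word_counts(generated), word_counts(reference)
    gen_total, ref_total = sum(gen.values()), sum(ref.values())
    vocab_size = len(set(gen) | set(ref))
    ratios = []
    for word, count in gen.items():
        gen_share = count / gen_total
        ref_share = (ref[word] + 1) / (ref_total + vocab_size)
        ratios.append((word, gen_share / ref_share))
    ratios.sort(key=lambda item: item[1], reverse=True)
    return ratios[:top_k]
```

A ratio well above 1 for words like "skateboard" would flag exactly the kind of collapse described; where to draw the threshold is a judgment call.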

CR-Gjx commented 6 years ago

I'm working on the aforementioned problems, and I think there is a lot of work to do. We can share our progress on solving these problems~