HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

Confused about the step that "Trains the generator to maximize the discriminator loss" #694

Open Liuzy0908 opened 1 year ago

Liuzy0908 commented 1 year ago

Problem

Hi, imitation is a great project!

I am currently training the GAIL algorithm, using SB3's PPO as the learner. I have a question about the training process in imitation\GAIL\train_gen, which is called from the training loop:

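```python
import tqdm
from imitation.util import networks

def train(self, total_timesteps: int) -> None:
    n_rounds = total_timesteps // self.gen_train_timesteps
    for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
        self.train_gen(self.gen_train_timesteps)  # <-- the step I am confused about
        for _ in range(self.n_disc_updates_per_round):
            with networks.training(self.reward_train):
                # switch to training mode (affects dropout, normalization)
                self.train_disc()
```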

In the code above, self.train_gen calls the learn method in SB3\on_policy_algorithm, which updates PPO's actor network and critic network.

So I am confused: training the generator seems no different from ordinary PPO training, and this process appears to have nothing to do with the discriminator.

Can you explain how imitation.algorithms.adversarial.gail.GAIL\train_gen actually "Trains the generator to maximize the discriminator loss"?

Looking forward to your reply!

Sincerely

Liuzy0908 commented 1 year ago

Generally, the generator's loss is the negative of the discriminator's score for a state-action pair (i.e., G_loss = -D(s, a)), but in imitation\gail I find that the generator (SB3's PPO actor network) is still optimized using SB3's PPO critic network.

https://github.com/HumanCompatibleAI/imitation/blob/df2627446b7457758a0b09fa74a1fcb19403a236/src/imitation/algorithms/adversarial/common.py#L447-L452

In other words, I can't find where a generator loss G_loss = -D(s, a) is computed and G_loss.backward() is called.
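For comparison with the -D(s, a) form above, here is a hedged sketch of how a GAIL-style generator loss can be rewritten as a per-step reward. It assumes that the discriminator outputs a logit L(s, a) and that D(s, a) = σ(L(s, a)) is its probability that (s, a) came from the expert; this is our reading of the docstring of imitation's RewardNetFromDiscriminatorLogit, so please check the library source. It uses the log form of the generator loss and ignores GAIL's entropy regularizer:

$$
D(s,a) = \sigma\bigl(L(s,a)\bigr) = \frac{1}{1 + e^{-L(s,a)}}
$$

$$
\min_\pi \; \mathbb{E}_{(s,a)\sim\pi}\bigl[\log\bigl(1 - D(s,a)\bigr)\bigr]
\;\Longleftrightarrow\;
\max_\pi \; \mathbb{E}_{(s,a)\sim\pi}\bigl[r(s,a)\bigr],
\qquad
r(s,a) = -\log\bigl(1 - D(s,a)\bigr) = -\log\sigma\bigl(-L(s,a)\bigr) = \log\bigl(1 + e^{L(s,a)}\bigr)
$$

Because r is non-decreasing in L, maximizing it pushes the policy toward pairs the discriminator rates as expert-like. The expectation is over state-action pairs generated by running π in the environment, so its gradient is estimated with a policy-gradient method rather than by backpropagating through D.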

I found the explanation in https://github.com/HumanCompatibleAI/imitation/issues/635#issuecomment-1329630844 unclear. Can you explain exactly how the discriminator drives the updates of the generator in imitation\gail?

This is important to me. Thank you for your time. @AdamGleave @ernestum @shwang @dfilan

ThomasRochefortB commented 1 year ago

@Liuzy0908, here is how I understand it:

When you instantiate GAIL, the environment is wrapped so that the reward used by the .learn method is computed from the discriminator (reward_train) and corresponds to the GAIL generator objective, instead of the environment's own reward. See here:

https://github.com/HumanCompatibleAI/imitation/blob/df2627446b7457758a0b09fa74a1fcb19403a236/src/imitation/algorithms/adversarial/common.py#L225C6-L233C14
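To illustrate the idea behind that wrapping, here is a heavily simplified sketch in plain Python. ToyEnv, LearnedRewardWrapper, and gail_reward are hypothetical names made up for illustration; this is not imitation's actual wrapper class, and the fixed discriminator logits are assumptions. The point is that the RL learner only ever sees the reward returned by the wrapper, so what it maximizes is the discriminator-derived reward.

```python
import math
from typing import Callable, Optional, Tuple

# Hypothetical reward-function type: (obs, action, next_obs, done) -> reward.
RewardFn = Callable[[float, int, float, bool], float]


class ToyEnv:
    """Hypothetical one-step environment: observation is always 0.0, actions are 0 or 1."""

    def reset(self) -> float:
        return 0.0

    def step(self, action: int) -> Tuple[float, float, bool, dict]:
        # The environment's own reward; the wrapper below discards it.
        return 0.0, 100.0, True, {}


class LearnedRewardWrapper:
    """Hypothetical wrapper that replaces the env reward with a learned reward."""

    def __init__(self, env: ToyEnv, reward_fn: RewardFn) -> None:
        self.env = env
        self.reward_fn = reward_fn
        self._obs: Optional[float] = None

    def reset(self) -> float:
        self._obs = self.env.reset()
        return self._obs

    def step(self, action: int) -> Tuple[float, float, bool, dict]:
        next_obs, env_reward, done, info = self.env.step(action)
        # Keep the original reward for logging only; the learner never optimizes it.
        info = dict(info, original_env_reward=env_reward)
        reward = self.reward_fn(self._obs, action, next_obs, done)
        self._obs = next_obs
        return next_obs, reward, done, info


def gail_reward(obs: float, action: int, next_obs: float, done: bool) -> float:
    """Reward from a frozen, hypothetical discriminator logit L(s, a)."""
    logit = 2.0 if action == 1 else -2.0  # action 1 looks "expert-like"
    # r = -log(1 - sigmoid(L)) = log(1 + exp(L))
    return math.log1p(math.exp(logit))


env = LearnedRewardWrapper(ToyEnv(), gail_reward)
env.reset()
_, reward, _, info = env.step(1)
print(reward, info["original_env_reward"])  # reward ≈ 2.127, original 100.0
```

Any learner that calls step on the wrapped environment (PPO included) is therefore optimizing the discriminator's judgement, even though its own update rule is unchanged.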

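So even though we still call PPO's .learn method, each call performs policy-gradient steps on the GAIL objective: PPO maximizes the discriminator-derived reward, which pushes the policy toward state-action pairs the discriminator judges as expert-like. No explicit G_loss.backward() through the discriminator is needed.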
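A toy example may make the "policy gradient on the discriminator reward" point concrete. This is a simplified stand-in, not imitation's or SB3's code: a two-action softmax policy (a one-step bandit), frozen hypothetical discriminator logits, and an exact REINFORCE (score-function) gradient in place of PPO's sampled, clipped update. The rewards enter only as numbers, so no gradient ever flows into the discriminator, yet the policy still moves toward the action the discriminator rates as expert-like.

```python
import math
from typing import List


def softplus(logit: float) -> float:
    """GAIL-style reward from a discriminator logit: -log(1 - sigmoid(logit))."""
    return math.log1p(math.exp(logit))


def softmax(theta: List[float]) -> List[float]:
    m = max(theta)  # subtract the max for numerical stability
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def policy_gradient_step(
    theta: List[float], rewards: List[float], lr: float = 0.5
) -> List[float]:
    """One exact REINFORCE step for a softmax policy over discrete actions.

    d/dtheta_j E_{a~pi}[r(a)] = pi(j) * (r(j) - E_pi[r]).
    The rewards are plain numbers here: no gradient flows into the discriminator.
    """
    probs = softmax(theta)
    expected_reward = sum(p * r for p, r in zip(probs, rewards))
    grad = [p * (r - expected_reward) for p, r in zip(probs, rewards)]
    return [t + lr * g for t, g in zip(theta, grad)]  # gradient ascent


# Hypothetical frozen discriminator logits: it rates action 1 as expert-like.
disc_logits = [-2.0, 2.0]
rewards = [softplus(logit) for logit in disc_logits]

theta = [0.0, 0.0]
for _ in range(50):
    theta = policy_gradient_step(theta, rewards)

print(softmax(theta))  # probability of action 1 approaches 1
```

In real GAIL the discriminator is retrained between generator rounds (the train_disc loop in the training snippet above), so the reward keeps changing, and PPO adds a learned critic for advantage estimation plus a clipped surrogate objective on top of this basic score-function gradient.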