Div99 / IQ-Learn

(NeurIPS '21 Spotlight) IQ-Learn: Inverse Q-Learning for Imitation
https://div99.github.io/IQ-Learn/

Critic function is diverging while using SAC #11

Closed mw9385 closed 1 year ago

mw9385 commented 1 year ago

Hi, thank you for providing such wonderful code. I am trying to adopt the IQ method in my custom environment, but my critic loss keeps diverging. I tried copying and pasting the original code from GitHub, and the same thing happens again and again. Is this normal when the IQ imitation learning method is combined with SAC, or am I using it in a wrong way? I have included my code below, together with a plot of my loss. [Image: loss_function]


```python
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

# Actor, Critic and the project-specific DataLoader(train, args) come from my own
# modules (not shown). Critic returns two Q heads: q1, q2 = critic(...).


class IQ(nn.Module):
    def __init__(self, args):
        super(IQ, self).__init__()

        self.args = args

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.actor = Actor(self.args).to(self.device)

        self.q = Critic(self.args).to(self.device)
        # self.q_2 = Critic(self.args).to(self.device)

        self.target_q = Critic(self.args).to(self.device)
        # self.target_q_2 = Critic(self.args).to(self.device)

        self.soft_update(self.q, self.target_q, 1.)
        # self.soft_update(self.q_2, self.target_q_2, 1.)

        # self.alpha = nn.Parameter(torch.tensor(self.args.alpha_init))
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(1e-3)))
        self.target_entropy = - torch.tensor(self.args.p_len * self.args.state_dim)

        self.q_optimizer = optim.Adam(self.q.parameters(), lr=self.args.q_lr)
        # self.q_2_optimizer = optim.Adam(self.q_2.parameters(), lr=self.args.q_lr)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.args.q_lr)

        # make sure the checkpoint directory exists
        os.makedirs(self.args.pretrain_model_dir, exist_ok=True)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def get_action(self, depth, imu, dir_vector):
        # normalization (self.normalization is my own helper, not shown)
        depth, imu, dir_vector = self.normalization(depth, imu, dir_vector)
        mu, std = self.actor(depth, imu, dir_vector)
        std = std.exp()  # the actor outputs log-std
        dist = Normal(mu, std)
        u = dist.rsample()
        u_log_prob = dist.log_prob(u)
        a = torch.tanh(u)
        # tanh-squashing correction of the log-probability
        a_log_prob = u_log_prob - torch.log(1 - torch.square(a) + 1e-3)
        return a, a_log_prob.sum(-1, keepdim=True)

    def q_update(self, current_Q, current_v, next_v, done_masks, is_expert):
        #  calculate 1st term for IQ loss
        #  -E_(ρ_expert)[Q(s, a) - γV(s')]
        with torch.no_grad():
            y = (1 - done_masks) * self.args.gamma * next_v
        reward = (current_Q - y)[is_expert]

        # our proposed unbiased form for fixing kl divergence
        # 1st loss function: phi_grad is a constant weight, so no gradient flows through it
        with torch.no_grad():
            phi_grad = torch.exp(-reward)
        loss = -(phi_grad * reward).mean()

        ######
        # sample using expert and policy states (works online)
        # E_(ρ)[V(s) - γV(s')], 2nd loss function
        value_loss = (current_v - y).mean()
        loss += value_loss

        # Use χ2 divergence (calculate the regularization term for IQ loss using expert and policy states) (works online)
        reward = current_Q - y

        # the alpha value here is fixed at 0.5
        # chi2_loss = 1/(4 * self.alpha) * (reward**2).mean()
        chi2_loss = 1/(4 * 0.5) * (reward**2).mean()
        loss += chi2_loss
        ######
        return loss

    def train_network(self, writer, n_epi, train_memory):
        print("SAC UPDATE")
        # self.get_samples is my own replay-buffer helper (not shown)
        depth, imu, dir_vector, actions, rewards, next_depth, next_imu, next_dir_vector, done_masks, is_expert = \
            self.get_samples(train_memory)
        q1, q2 = self.q(depth, imu, dir_vector, actions)
        v1, v2 = self.getV(self.q, depth, imu, dir_vector)

        with torch.no_grad():
            next_v1, next_v2 = self.get_targetV(self.target_q, next_depth, next_imu, next_dir_vector)

        # q_update
        q1_loss = self.q_update(q1, v1, next_v1, done_masks, is_expert)
        q2_loss = self.q_update(q2, v2, next_v2, done_masks, is_expert)

        # define critic loss
        critic_loss = 1/2 * (q1_loss + q2_loss)
        # update
        self.q_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q.parameters(), 1.0)
        # step critic
        self.q_optimizer.step()

        ### actor update
        actor_loss, prob = self.actor_update(depth, imu, dir_vector)
        ### alpha update (disabled: alpha stays at its initial value)
        # alpha_loss = self.alpha_update(prob)

        # self.q holds both Q heads, so one soft update covers them;
        # self.q_2 / self.target_q_2 are never created, so updating them would raise AttributeError
        self.soft_update(self.q, self.target_q, self.args.soft_update_rate)

        if writer is not None:
            writer.add_scalar("loss/q_1", q1_loss.item(), n_epi)
            writer.add_scalar("loss/q_2", q2_loss.item(), n_epi)
            writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
            # alpha_loss only exists when alpha_update() above is enabled
            # writer.add_scalar("loss/alpha", alpha_loss.item(), n_epi)
        # save model
        if np.mod(n_epi, self.args.save_period) == 0 and n_epi > 0:
            torch.save(self.actor.state_dict(), os.path.join(self.args.pretrain_model_dir, 'actor.pt'))

    def actor_update(self, depth, imu, dir_vector):
        now_actions, now_action_log_prob = self.get_action(depth, imu, dir_vector)
        q_1, q_2 = self.q(depth, imu, dir_vector, now_actions)
        q = torch.min(q_1, q_2)
        loss = (self.alpha.detach() * now_action_log_prob - q).mean()

        self.actor_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optimizer.step()
        return loss, now_action_log_prob

    def alpha_update(self, now_action_log_prob):
        loss = (- self.alpha * (now_action_log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        loss.backward()
        self.alpha_optimizer.step()
        return loss

    def soft_update(self, network, target_network, rate):
        # Polyak averaging: target <- (1 - rate) * target + rate * online
        for network_params, target_network_params in zip(network.parameters(), target_network.parameters()):
            target_network_params.data.copy_(target_network_params.data * (1.0 - rate) + network_params.data * rate)

    def get_expert_data(self):
        # define train and validation dataset (project-specific DataLoader, not torch.utils.data)
        self.expert_dataloader = DataLoader(True, self.args)
        # load expert and training dataset
        expert_depth, expert_imu, expert_dir_vector, expert_action, reward, done, expert_next_depth, expert_next_imu, expert_next_dir_vector \
            = self.expert_dataloader.__getitem__(batch_size=self.args.discrim_batch_size)
        # preprocessing training label
        expert_action = self.label_preprocessing(expert_action)

        # convert numpy arrays into float32 tensors on the model's device
        def to_tensor(x):
            return torch.as_tensor(x, dtype=torch.float32, device=self.device)

        expert_depth = to_tensor(expert_depth)
        expert_imu = to_tensor(expert_imu)
        expert_dir_vector = to_tensor(expert_dir_vector)
        expert_action = to_tensor(expert_action)

        expert_next_depth = to_tensor(expert_next_depth)
        expert_next_imu = to_tensor(expert_next_imu)
        expert_next_dir_vector = to_tensor(expert_next_dir_vector)

        return expert_depth, expert_imu, expert_dir_vector, expert_action, expert_next_depth, expert_next_imu, expert_next_dir_vector

    def getV(self, critic, depth, imu, dir_vector):
        # soft value V(s) = Q(s, a) - alpha * log pi(a|s), with a sampled from the current policy
        action, log_prob = self.get_action(depth, imu, dir_vector)
        current_Q1, current_Q2 = critic(depth, imu, dir_vector, action)
        current_V1 = current_Q1 - self.alpha.detach() * log_prob
        current_V2 = current_Q2 - self.alpha.detach() * log_prob
        return current_V1, current_V2

    def get_targetV(self, critic_target, depth, imu, dir_vector):
        action, log_prob = self.get_action(depth, imu, dir_vector)
        target_Q1, target_Q2 = critic_target(depth, imu, dir_vector, action)
        target_V1 = target_Q1 - self.alpha.detach() * log_prob
        target_V2 = target_Q2 - self.alpha.detach() * log_prob
        return target_V1, target_V2
```
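In `q_update`, the weight `phi_grad = torch.exp(-reward)` multiplies `reward` in the first loss term, and whether gradients flow through `phi_grad` changes the objective being optimized. The plain-Python sketch below (no PyTorch needed; the function names are hypothetical) compares the per-sample derivative of `-(phi_grad * reward)` when `phi_grad` is treated as a constant weight, i.e. computed under `torch.no_grad()`, against the fully differentiated expression. Reading `phi_grad` as a constant weight is an assumption about the intended "unbiased form", not something stated in this thread.

```python
import math


def grad_constant_weight(r: float) -> float:
    # phi = exp(-r) held fixed (detached): d/dr [-(phi * r)] = -phi
    return -math.exp(-r)


def grad_through_weight(r: float) -> float:
    # phi = exp(-r) differentiated as well: d/dr [-exp(-r) * r] = exp(-r) * (r - 1)
    return math.exp(-r) * (r - 1.0)


def finite_difference(f, r: float, h: float = 1e-6) -> float:
    # central difference, used only to check the analytic derivative above
    return (f(r + h) - f(r - h)) / (2.0 * h)


if __name__ == "__main__":
    for r in (0.0, 0.5, 2.0):
        numeric = finite_difference(lambda x: -math.exp(-x) * x, r)
        print(f"r={r}: detached={grad_constant_weight(r):+.4f}, "
              f"attached={grad_through_weight(r):+.4f}, numeric={numeric:+.4f}")
```

With the weight detached, the derivative is always negative, so gradient descent always pushes expert rewards up (with a weight that shrinks as they grow). When gradients also flow through `exp(-r)`, the sign flips for `r > 1`, and the term instead pulls every expert reward toward `r = 1`.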
mw9385 commented 1 year ago

And these are my hyperparameters:

    parser.add_argument("--memory_size", type=int, default=20000)
    parser.add_argument("--random_action", type=int, default=1000)  # Don't need seeding for IL (Use 1000 for RL)
    parser.add_argument("--min_samples_to_start", type=int, default=1000)
    parser.add_argument("--alpha_init", type=float, default=0.5)
    parser.add_argument("--soft_update_rate", type=float, default=0.005)
    parser.add_argument("--mini_batch_size", type=int, default=128)
    parser.add_argument("--save_period", type=int, default=200)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lambda", type=float, default=0.95)
    parser.add_argument("--actor_lr", type=float, default=3e-5)
    parser.add_argument("--q_lr", type=float, default=3e-5)
    parser.add_argument("--actor_train_epoch", type=int, default=1)

Altriaex commented 1 year ago

Hi, I observed similar behaviors. In your code you set

reward = (current_Q - y)[is_expert]

and compute the chi2 regularization only for expert "reward" as

chi2_loss = 1/(4 * 0.5) * (reward**2).mean()

which, in my experience, leads to divergence. The reason is that these "rewards" are in fact very large. If you look at the iq.py file, you can see that the authors compute the chi2 regularization on both the policy's and the expert's "reward". With that setup I do not have the divergence problem, but I am still not able to get good policies.

Another thing to point out is that the authors do not update alpha.

mw9385 commented 1 year ago

@Altriaex Thanks for your reply! Actually, I computed my reward using both the expert and the learner data. In the first loss term, I set my reward as reward = (current_Q - y)[is_expert], and the corresponding loss is defined as loss = -(reward).

In the chi2 regularization, I again set my reward as reward = (current_Q - y), and the corresponding loss is defined as chi2_loss = 1/(4 * 0.5) * (reward**2).mean(), which already uses both the expert and the learner data.

Should I change the first reward from (current_Q - y)[is_expert] to current_Q - y for all loss terms, or only use current_Q - y for chi2_loss?

I will try without updating alpha. If you have any loss plots for your own custom environment, could you share them?

Many thanks.

Altriaex commented 1 year ago

I still cannot make this algorithm work myself, so I also don't know what the best thing to do is.

mw9385 commented 1 year ago

Thanks. I will try without training alpha and let you know the results :)

mw9385 commented 1 year ago

@Div99 Hi, is divergence of the critic a normal phenomenon in IQ-Learn, or am I using the code in a wrong way? Thanks in advance :)

Div99 commented 1 year ago

Hi, sorry for the delay in replying. I have observed that for continuous spaces you need to add the chi2 regularization on both the policy and the expert samples. The reason is that you have a separate policy network in the continuous setting, and without also regularizing policy samples, we can learn large negative rewards for the policy that diverge toward negative infinity, preventing the method from converging.
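A minimal sketch of the difference described above, written in plain Python rather than PyTorch so it runs anywhere. `chi2_regularizer` and the batch values are hypothetical; `include_policy=False` corresponds to `(reward**2)[is_expert].mean()` and `include_policy=True` to `(reward**2).mean()`, both scaled by `1/(4 * chi2_alpha)` as in the code posted in this issue.

```python
def chi2_regularizer(rewards, is_expert, chi2_alpha=0.5, include_policy=True):
    """1/(4 * chi2_alpha) * mean(r^2), over all samples or over expert samples only."""
    if include_policy:
        selected = list(rewards)  # expert + policy samples ("regularize" behavior)
    else:
        selected = [r for r, expert in zip(rewards, is_expert) if expert]  # expert only
    return sum(r * r for r in selected) / (4.0 * chi2_alpha * len(selected))


# Hypothetical batch: two expert samples with small implicit rewards and two
# policy samples whose implicit rewards have drifted far negative.
rewards = [0.1, -0.2, -5.0, -8.0]
is_expert = [True, True, False, False]

expert_only = chi2_regularizer(rewards, is_expert, include_policy=False)
both = chi2_regularizer(rewards, is_expert, include_policy=True)
print(expert_only, both)
```

The expert-only penalty (0.0125) does not see the drifting policy rewards at all, while the penalty over all samples (about 11.13) grows quadratically with them, which is what keeps them from running away.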

For IQ-Learn on continuous spaces, I recommend setting method.regularize=True to enable the above behavior, training with a single Q-network (instead of a double critic), and disabling alpha training while trying small alpha values like 1e-3 or 1e-2. If you are using the original code in the repo, you can try one of the settings used in our Mujoco experiments script run_mujoco.sh.

For automatic alpha training, see this issue: https://github.com/Div99/IQ-Learn/issues/5. In general, we want the imitation policy to have much lower entropy than in SAC, and setting entropy_target = -4 * dim(A) works well on most Mujoco environments when learning alpha.
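A plain-Python sketch of the learned-temperature update with the suggested target, using a scalar `log_alpha`. The function names and numbers are hypothetical; the objective mirrors the `alpha_update()` posted earlier in this issue (`loss = -alpha * (log_prob + target_entropy)` with `log_prob` detached), and the gradient step is plain gradient descent rather than Adam.

```python
import math


def imitation_target_entropy(action_dim: int, scale: float = 4.0) -> float:
    # Suggested target for imitation: -4 * dim(A); SAC commonly uses -dim(A)
    return -scale * action_dim


def alpha_step(log_alpha: float, mean_log_prob: float, target_entropy: float, lr: float) -> float:
    # loss = -exp(log_alpha) * (mean_log_prob + target_entropy), log_prob treated as a constant,
    # so d loss / d log_alpha = -exp(log_alpha) * (mean_log_prob + target_entropy)
    grad = -math.exp(log_alpha) * (mean_log_prob + target_entropy)
    return log_alpha - lr * grad  # one gradient-descent step


target = imitation_target_entropy(6)  # e.g. dim(A) = 6
log_alpha = math.log(1e-2)
# Policy entropy estimate (-mean_log_prob = 10) is still above the target (-24),
# so the step should lower alpha.
new_log_alpha = alpha_step(log_alpha, mean_log_prob=-10.0, target_entropy=target, lr=3e-4)
print(target, math.exp(new_log_alpha))
```

Because the target is much lower than SAC's usual `-dim(A)`, alpha keeps shrinking until the policy's entropy estimate falls to about `-4 * dim(A)`, which is the low-entropy behavior wanted for imitation.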

Div99 commented 1 year ago

> @Div99 Hi, is divergence of the critic a normal phenomenon in IQ-Learn, or am I using the code in a wrong way? Thanks in advance :)

No, the critic should not diverge if the method is working well. Divergence most likely indicates a bug in the code or a wrong hyperparameter setting.

mw9385 commented 1 year ago

@Div99 Thanks for your reply. I was waiting for you! I am running my custom code in a vision-based collision avoidance environment. The policy network takes visual inputs and produces a collision-free trajectory (3 points in 2D space, so dim(A) = 6). The policy and critic networks follow the same structure as the one used in the Atari example.

I have tried with the following settings:

After training, I got the two loss curves below: the first is the actor loss and the second is the Q loss. I am still suffering from Q-function divergence. When I printed the Q values, I could see some large negative values, which result in a huge Q loss. I need to check again whether I am using your method correctly by running your original code. Or can you guess any potential cause of the divergence?

[Image: actor loss curve (loss/actor_loss)]

[Image: Q loss curve (loss/q_1)]

Div99 commented 1 year ago

We use a critic learning rate of 3e-4, so the difference from your setting could be one source of divergence.

I would also recommend trying a higher alpha, such as 1e-2 or 1e-1, if the above fix doesn't help. There could also be an issue with how the expert data is generated and whether it matches the policy data exactly (observation normalization, etc.).
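One way to rule out the preprocessing mismatch mentioned above is to fit a single set of normalization statistics and route both expert and policy observations through the same object. The sketch below assumes flat, per-dimension standardization; `ObsNormalizer` is a hypothetical helper, not part of the IQ-Learn repo, and image inputs like the depth maps here would need their own (but equally shared) preprocessing.

```python
import math


class ObsNormalizer:
    """Hypothetical helper: one set of statistics shared by expert and policy data."""

    def __init__(self, eps: float = 1e-8):
        self.mean = None
        self.std = None
        self.eps = eps

    def fit(self, observations):
        # per-dimension mean and population standard deviation
        n = len(observations)
        dims = len(observations[0])
        self.mean = [sum(o[d] for o in observations) / n for d in range(dims)]
        self.std = [math.sqrt(sum((o[d] - self.mean[d]) ** 2 for o in observations) / n)
                    for d in range(dims)]
        return self

    def __call__(self, obs):
        return [(x - m) / (s + self.eps) for x, m, s in zip(obs, self.mean, self.std)]


expert_obs = [[1.0, 10.0], [3.0, 30.0]]
norm = ObsNormalizer().fit(expert_obs)
expert_in = [norm(o) for o in expert_obs]  # expert batch
policy_in = norm([2.0, 20.0])              # learner observation, same statistics
print(expert_in, policy_in)
```

The point of the design is that there is exactly one normalizer instance; if expert and policy observations are preprocessed by different code paths, the critic effectively sees two input distributions and can separate them trivially.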

Altriaex commented 1 year ago

@mw9385 What about printing out your rewards? If you include the chi2 term, in theory you should have very small rewards, which should help you prevent divergence.
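A small sketch of what "printing out your rewards" can look like: compute the implicit reward exactly as in the `q_update()` posted in this issue, `reward = Q(s, a) - (1 - done) * gamma * V(s')`, and summarize it separately for expert and policy samples. It is written in plain Python with made-up numbers; in PyTorch the same split is `reward[is_expert]` and `reward[~is_expert]` on a boolean mask.

```python
def implicit_rewards(current_q, next_v, done, gamma=0.99):
    # reward = Q(s, a) - (1 - done) * gamma * V(s')
    return [q - (1.0 - d) * gamma * v for q, v, d in zip(current_q, next_v, done)]


def summarize(rewards, is_expert):
    # mean / min / max reported separately for expert and policy samples
    out = {}
    for name, flag in (("expert", True), ("policy", False)):
        group = [r for r, e in zip(rewards, is_expert) if e == flag]
        if group:
            out[name] = {"mean": sum(group) / len(group), "min": min(group), "max": max(group)}
    return out


rewards = implicit_rewards(current_q=[1.0, 0.5, -3.0], next_v=[1.0, 0.0, -2.0], done=[0.0, 1.0, 0.0])
stats = summarize(rewards, is_expert=[True, True, False])
print(stats)
```

If the policy group's mean keeps falling from one logging step to the next while the expert group stays small, the critic is pushing policy rewards toward negative infinity rather than converging.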

For me, it turns out that the key is to use a single Q-function as the critic, as opposed to SAC's double-Q solution.

Div99 commented 1 year ago

> @mw9385 What about printing out your rewards? If you include the chi2 term, in theory you should have very small rewards, which should help you prevent divergence.
>
> For me, it turns out that the key is to use a single Q-function as the critic, as opposed to SAC's double-Q solution.

Great! Glad the single Q-network worked. It's not clear why the double-Q trick works for SAC but not here; maybe the min prevents learning the correct rewards.
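The guess about the min can be made concrete with a quick simulation. This does not show that the min caused the divergence in this issue; it only illustrates the known downward bias of taking the minimum of two noisy, individually unbiased estimates (for two independent N(Q, σ²) estimates, the expected minimum is Q - σ/√π). In IQ-Learn the critic's Q values define the implicit rewards, so a systematic pessimistic shift would feed into them. The function name and noise model are hypothetical.

```python
import math
import random
import statistics


def mean_estimates(true_q=0.0, noise_std=1.0, n=100_000, seed=0):
    # Two independent, unbiased noisy estimates of the same Q value per sample
    rng = random.Random(seed)
    singles, minima = [], []
    for _ in range(n):
        q1 = true_q + rng.gauss(0.0, noise_std)
        q2 = true_q + rng.gauss(0.0, noise_std)
        singles.append(q1)           # single critic
        minima.append(min(q1, q2))   # clipped double-Q target
    return statistics.fmean(singles), statistics.fmean(minima)


single_mean, min_mean = mean_estimates()
print(single_mean, min_mean, -1.0 / math.sqrt(math.pi))
```

The single estimate averages to about the true value 0, while the min averages to about -0.564. SAC uses this pessimism on purpose to counter overestimation; whether it interacts badly with IQ-Learn's reward recovery remains a hypothesis.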

mw9385 commented 1 year ago

> @mw9385 What about printing out your rewards? If you include the chi2 term, in theory you should have very small rewards, which should help you prevent divergence.
>
> For me, it turns out that the key is to use a single Q-function as the critic, as opposed to SAC's double-Q solution.

When I use double Q-networks, the reward values are in the [-1, 1] range, which is not that high. I will try a single Q-function! Thank you so much.

mw9385 commented 1 year ago

> We use a critic learning rate of 3e-4, so the difference from your setting could be one source of divergence.
>
> I would also recommend trying a higher alpha, such as 1e-2 or 1e-1, if the above fix doesn't help. There could also be an issue with how the expert data is generated and whether it matches the policy data exactly (observation normalization, etc.).

I will try a critic learning rate of 3e-4 with a single Q-network, and set my initial alpha value to 1e-2. I will let you know my results. I will also check whether my network inputs are correctly normalized.

mw9385 commented 1 year ago

@Altriaex @Div99 Hi, I have tried a single Q critic and it works: I no longer see any divergence of the critic loss. I also ran the original code in the repo, and my loss shows similar behavior. The cause of the divergence was that the critic produced negative outputs (meaning it judged the current states and actions to be bad), and as training went on, the Q values became more and more negative, resulting in divergence. Using a single critic removes this bug.

Many thanks :)

mw9385 commented 1 year ago

@Div99 Sorry, I have to reopen the issue because the loss seems to be very unstable: it fluctuates with a large magnitude. The following are my hyperparameters:

[Image: Screenshot 2023-01-20 at 10.52.36 PM]

[Image: Screenshot 2023-01-20 at 10.52.43 PM]

mw9385 commented 1 year ago

I have solved this issue by tuning hyperparameters. Closing this issue.