niiceMing / CMTA

(NIPS23) Contrastive Modules with Temporal Attention for Multi-Task Reinforcement Learning

single SAC baseline critic does not work #5

Closed Lan131 closed 1 month ago

Lan131 commented 1 month ago

File "/home/research/lanier.m/Desktop/CMTA/mtrl/agent/components/critic.py", line 495, in forward encoding, next_hidden, ori_encoding = self.encode(mtobs=temp_mtobs, detach=detach_encoder) ValueError: not enough values to unpack (expected 3, got 2)

If I try to bypass this by removing ori_encoding, it eventually fails here:

    info_nce_loss = 0
    info_nce = InfoNCE(negative_mode='paired')

    batch_size, num_negative, embedding_size = 1280, 5, 64

    # loop over the 6 experts (experts_num)
    for i in range(6):
        query = ori_encoding[i]
        positive_key = ori_next_encoding[i]
        # negatives from the other experts: [5, batch, emb_size]
        negative_keys = torch.cat((ori_encoding[0:i], ori_encoding[i+1:6]), 0)
        # [5, batch, emb_size] -> [batch, 5, emb_size]
        negative_keys = torch.transpose(negative_keys, 0, 1)
        info_nce_loss += info_nce(query, positive_key, negative_keys)
    logger.log("train/info_nce_loss", info_nce_loss, step)
    total_loss = critic_loss + 2500 * info_nce_loss

This fails in sac.py because ori_encoding doesn't exist there.
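
A note on the shape contract here: with negative_mode='paired', the InfoNCE class from the info-nce-pytorch package expects query [batch, emb_size], positive_key [batch, emb_size], and negative_keys [batch, num_negative, emb_size], which is what the transpose above produces. A self-contained shape check with random stand-in tensors (num_experts = 6 is an assumption taken from range(6)):

    # Standalone shape check for the InfoNCE call above; the tensors are
    # random stand-ins, not the actual CMTA encodings.
    # Assumes the info-nce-pytorch package: pip install info-nce-pytorch
    import torch
    from info_nce import InfoNCE

    num_experts, batch_size, embedding_size = 6, 1280, 64
    info_nce = InfoNCE(negative_mode='paired')

    ori_encoding = torch.randn(num_experts, batch_size, embedding_size)
    ori_next_encoding = torch.randn(num_experts, batch_size, embedding_size)

    info_nce_loss = 0
    for i in range(num_experts):
        query = ori_encoding[i]              # [batch, emb_size]
        positive_key = ori_next_encoding[i]  # [batch, emb_size]
        # the other experts' encodings as negatives: [5, batch, emb_size]
        negative_keys = torch.cat((ori_encoding[:i], ori_encoding[i + 1:]), 0)
        # paired mode wants [batch, num_negative, emb_size]
        negative_keys = torch.transpose(negative_keys, 0, 1)
        info_nce_loss += info_nce(query, positive_key, negative_keys)

    print(info_nce_loss)  # scalar tensor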
Lan131 commented 1 month ago

The error occurs for soft_modu as well.

Lan131 commented 1 month ago

All of the baselines except CARE and CMTA have this issue.

niiceMing commented 1 month ago

The latest commit has fixed this bug.