MishaLaskin / curl

CURL: Contrastive Unsupervised Representation Learning for Sample-Efficient Reinforcement Learning
MIT License
561 stars 88 forks

An error when computing CURL #12

Closed DanielS684 closed 3 years ago

DanielS684 commented 3 years ago

I noticed, reading through the paper and the code, that the pseudocode in the paper says the key encoder needs to be detached from the graph, but in the actual code you don't pass detach=True for z_pos = self.CURL.encode(obs_pos, ema=True). I wanted to know whether the paper or the code is correct, or maybe I am missing some part of the computation.

This is what is in the code for curl_sac.py:

def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step):
    z_a = self.CURL.encode(obs_anchor)
    z_pos = self.CURL.encode(obs_pos, ema=True)

    logits = self.CURL.compute_logits(z_a, z_pos)
    labels = torch.arange(logits.shape[0]).long().to(self.device)
    loss = self.cross_entropy_loss(logits, labels)

    self.encoder_optimizer.zero_grad()
    self.cpc_optimizer.zero_grad()
    loss.backward()

    self.encoder_optimizer.step()
    self.cpc_optimizer.step()
    if step % self.log_interval == 0:
        L.log('train/curl_loss', loss, step)

and this is what is in the pseudocode for the paper:

for x in loader: 
    x_q = aug(x)
    x_k = aug(x)
    z_q = f_q.forward(x_q)
    z_k = f_k.forward(x_k)
    z_k = z_k.detach()
    proj_k = matmul(W, z_k.T)
    logits = matmul(z_q, proj_k)
    logits = logits - max(logits, axis=1)
    labels = arange(logits.shape[0])
    loss = CrossEntropyLoss(logits, labels)
    loss.backward()
    update(f_q.params)
    update(W)
    f_k.params = m*f_k.params+(1-m)*f_q.params
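The paper's pseudocode above can be turned into a runnable PyTorch sketch. Everything here is illustrative (the linear "encoders", the dimensions, the momentum value are my assumptions, not the repo's actual architecture), but it shows the key point of the question: the keys are detached, so gradients reach the query encoder f_q and the bilinear matrix W, but never the key encoder f_k, which is only updated by the EMA rule at the end.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, D = 8, 16                                 # batch size, latent dim (assumed)
f_q = nn.Linear(32, D)                       # stand-in query encoder (trained)
f_k = nn.Linear(32, D)                       # stand-in key (momentum) encoder
f_k.load_state_dict(f_q.state_dict())        # start from the same weights
W = nn.Parameter(torch.randn(D, D) * 0.01)   # bilinear similarity matrix

x_q = torch.randn(B, 32)                     # "augmented" anchor batch
x_k = torch.randn(B, 32)                     # "augmented" positive batch

z_q = f_q(x_q)
z_k = f_k(x_k).detach()                      # stop-grad on the keys, as in the paper

proj_k = W @ z_k.T                           # (D, B)
logits = z_q @ proj_k                        # (B, B): diagonal entries are positives
logits = logits - logits.max(dim=1, keepdim=True).values  # numerical stability
labels = torch.arange(B)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()

# Gradients flow to f_q and W, but not to the detached key encoder f_k.
assert f_q.weight.grad is not None
assert W.grad is not None
assert f_k.weight.grad is None

# Momentum (EMA) update of the key encoder; m is close to 1 in practice.
m = 0.95
with torch.no_grad():
    for p_k, p_q in zip(f_k.parameters(), f_q.parameters()):
        p_k.mul_(m).add_((1 - m) * p_q)
```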
MishaLaskin commented 3 years ago

Take a look at the encode method; there's a stop_grad when ema is set to True.
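To make this concrete, here is a paraphrased sketch of what such an encode method looks like (a minimal stand-in, not a verbatim copy of the repo's code; the tiny linear encoders and class name are placeholders). With ema=True the momentum (target) encoder runs under torch.no_grad(), so its output is already cut off from the graph and no explicit detach=True is needed at the call site:

```python
import torch
import torch.nn as nn

class CURLSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 2)         # query encoder (trained)
        self.encoder_target = nn.Linear(4, 2)  # EMA key encoder (from critic_target)

    def encode(self, x, detach=False, ema=False):
        if ema:
            with torch.no_grad():              # the stop-grad lives here
                z_out = self.encoder_target(x)
        else:
            z_out = self.encoder(x)
            if detach:
                z_out = z_out.detach()
        return z_out

curl = CURLSketch()
z_pos = curl.encode(torch.randn(3, 4), ema=True)
assert not z_pos.requires_grad                 # already detached from the graph
```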

DanielS684 commented 3 years ago

@MishaLaskin Yeah, just read through the code again and realized my mistake, since it was using the encoder from critic_target instead. Really awesome work though!