tkipf / c-swm

Contrastive Learning of Structured World Models
MIT License

Discrepancy between the loss in the paper and the code on GitHub #3

Open bhattg opened 4 years ago

bhattg commented 4 years ago

According to the paper, the negative term of the contrastive loss is the distance between the negative state $\tilde{z}_t$ (randomly sampled from the embeddings at timestep $t$) and the ground-truth next state $z_{t+1}$.
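Written out, the paper's per-sample hinge loss (as I read it) is the following, where $d$ is the squared Euclidean distance, $T$ is the transition model, $a_t$ is the action, and $\gamma$ is the hinge margin:

```math
\mathcal{L} = d\big(z_t + T(z_t, a_t),\, z_{t+1}\big) + \max\big(0,\ \gamma - d(\tilde{z}_t,\, z_{t+1})\big)
```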

However, according to line 113 of modules.py, with `no_trans=True` you are effectively taking the distance between the negative sample drawn from the embeddings at timestep $t$ ($\tilde{z}_t$) and $z_t$ (rather than $z_{t+1}$).
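A minimal NumPy sketch of the discrepancy, assuming the `no_trans=True` branch of `energy` reduces to a normalized squared L2 distance (summed over the embedding dimension, averaged over objects). `energy_no_trans`, the array shapes, the fixed permutation, and `sigma=0.5` are illustrative choices, not code from the repo:

```python
import numpy as np

def energy_no_trans(a, b, sigma=0.5):
    """Assumed stand-in for energy(..., no_trans=True): normalized squared L2
    distance, summed over the embedding dim and averaged over objects.
    Inputs have shape (batch, num_objects, embedding_dim)."""
    norm = 0.5 / (sigma ** 2)
    diff = a - b
    return norm * (diff ** 2).sum(axis=2).mean(axis=1)

rng = np.random.default_rng(0)
state = rng.normal(size=(4, 3, 2))        # z_t
next_state = rng.normal(size=(4, 3, 2))   # z_{t+1}
perm = np.array([1, 2, 3, 0])             # fixed permutation for reproducibility
neg_state = state[perm]                   # \tilde{z}_t

# What the snippet computes for the negative term: d(z_t, \tilde{z}_t)
code_neg = energy_no_trans(state, neg_state)
# What the paper's negative term uses: d(\tilde{z}_t, z_{t+1})
paper_neg = energy_no_trans(neg_state, next_state)

# These generally differ for arbitrary embeddings.
print(code_neg, paper_neg)
```

Since this distance is symmetric, the argument order does not matter; what matters is which tensor `neg_state` is paired with.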

```python
def contrastive_loss(self, obs, action, next_obs):

    objs = self.obj_extractor(obs)
    next_objs = self.obj_extractor(next_obs)

    state = self.obj_encoder(objs)
    next_state = self.obj_encoder(next_objs)

    # Sample negative state across episodes at random
    batch_size = state.size(0)
    perm = np.random.permutation(batch_size)
    neg_state = state[perm]

    self.pos_loss = self.energy(state, action, next_state)
    zeros = torch.zeros_like(self.pos_loss)

    self.pos_loss = self.pos_loss.mean()
    # Negative term: with no_trans=True, neg_state is compared with
    # state (z_t), not with next_state (z_{t+1}) as in the paper
    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()

    loss = self.pos_loss + self.neg_loss

    return loss
```

So I think `next_state`, not `state`, should have been passed as the first argument of the energy function in the negative term. Please let me know if I am misreading anything.
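A standalone NumPy sketch of the proposed change, under the same assumption that `energy` is a normalized squared L2 distance summed over the embedding dimension and averaged over objects. `transition` is a hypothetical stand-in for `self.transition_model`, and `contrastive_loss_proposed` is an illustrative name, not code from the repo:

```python
import numpy as np

def energy(state, action, next_state, transition=None, no_trans=False, sigma=0.5):
    """Assumed sketch of the repo's energy. `transition` is a hypothetical
    stand-in for self.transition_model. Shapes: (batch, num_objects, dim)."""
    norm = 0.5 / (sigma ** 2)
    if no_trans:
        diff = state - next_state
    else:
        diff = state + transition(state, action) - next_state
    return norm * (diff ** 2).sum(axis=2).mean(axis=1)

def contrastive_loss_proposed(state, action, next_state, transition,
                              hinge=1.0, rng=None):
    """Hinge loss whose negative term pairs \tilde{z}_t with next_state
    (z_{t+1}), as proposed in this issue."""
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(state.shape[0])
    neg_state = state[perm]  # \tilde{z}_t, sampled from the batch at time t

    pos_loss = energy(state, action, next_state, transition).mean()
    # Proposed change: next_state instead of state as the first argument
    neg_energy = energy(next_state, action, neg_state, no_trans=True)
    neg_loss = np.maximum(0.0, hinge - neg_energy).mean()
    return pos_loss + neg_loss
```

For example, with a zero transition model such as `lambda s, a: np.zeros_like(s)`, setting `hinge=0.0` makes the negative term vanish, leaving only the positive term.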

Thanks.

AugustKarlstedt commented 4 years ago

Good catch. I wonder if this would fix the issue described in Figure 4b.

bhattg commented 4 years ago

Hey, do you mean this part of the caption of Figure 4: "One trajectory (in the center) strongly deviates from typical trajectories seen during training, and the model struggles to predict the correct transition."?

AugustKarlstedt commented 4 years ago

Yes, exactly.

BenchengY commented 4 years ago

I also wonder why `transition_model` is not applied to the negative state.