ethanfetaya / NRI

Neural relational inference for interacting systems - pytorch
MIT License

Where can I find the code of Eq. 12 in the paper?? #30

Open hoon0528 opened 3 years ago

hoon0528 commented 3 years ago

Below is the code snippet of MLPDecoder. I think the prediction ends with Eq. 11 in the paper, and I can't find the code for Eq. 12. Am I missing something in this code?

Thanks in advance.

    def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
                            single_timestep_rel_type):

        # single_timestep_inputs has shape
        # [batch_size, num_timesteps, num_atoms, num_dims]

        # single_timestep_rel_type has shape:
        # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types]

        # Node2edge 
        receivers = torch.matmul(rel_rec, single_timestep_inputs)
        senders = torch.matmul(rel_send, single_timestep_inputs)
        # Eq 10 [x_i^t, x_j^t] [#sims(batch_size), #tsteps_indexed, #edges, #dims*2]
        pre_msg = torch.cat([senders, receivers], dim=-1)
        # self.msg_out_shape = #node_features
        all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
                                        pre_msg.size(2), self.msg_out_shape))
        if single_timestep_inputs.is_cuda:
            all_msgs = all_msgs.cuda()

        if self.skip_first_edge_type:
            start_idx = 1
        else:
            start_idx = 0

        # Run separate MLP for every edge type
        # NOTE: To exclude one edge type, simply offset range by 1
        # Eq 10 MLP
        for i in range(start_idx, len(self.msg_fc2)):
            msg = F.relu(self.msg_fc1[i](pre_msg))
            msg = F.dropout(msg, p=self.dropout_prob)
            msg = F.relu(self.msg_fc2[i](msg))
            msg = msg * single_timestep_rel_type[:, :, :, i:i + 1] #element-wise product with broadcast
            all_msgs += msg

        # Aggregate all msgs to receiver
        # Eq 11 / rel_rec [#edges, #nodes]
        agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous()

        # Skip connection
        aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1)

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob)
        pred = self.out_fc3(pred)

        # Predict position/velocity difference / Eq 11 >> Where is Eq 12??
        return single_timestep_inputs + pred

    def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1):
        # NOTE: Assumes that we have the same graph across all samples.
        # Input shape: [num_sims, num_atoms, num_timesteps, num_dims] > [#sims, #tsteps, #nodes, #dims]
        inputs = inputs.transpose(1, 2).contiguous()

        sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1),
                 rel_type.size(2)]
        rel_type = rel_type.unsqueeze(1).expand(sizes)

        time_steps = inputs.size(1)
        assert (pred_steps <= time_steps)
        preds = []

        # Only take n-th timesteps as starting points (n: pred_steps)
        last_pred = inputs[:, 0::pred_steps, :, :]
        curr_rel_type = rel_type[:, 0::pred_steps, :, :]
        # NOTE: Assumes rel_type is constant (i.e. same across all time steps).

        # Run n prediction steps / Eq 10~11
        for step in range(0, pred_steps):
            last_pred = self.single_step_forward(last_pred, rel_rec, rel_send,
                                                 curr_rel_type)
            preds.append(last_pred)

        sizes = [preds[0].size(0), preds[0].size(1) * pred_steps,
                 preds[0].size(2), preds[0].size(3)]

        output = Variable(torch.zeros(sizes))
        if inputs.is_cuda:
            output = output.cuda()

        # Re-assemble correct timeline
        for i in range(len(preds)):
            output[:, i::pred_steps, :, :] = preds[i]
        # last prediction is one step beyond input
        pred_all = output[:, :(inputs.size(1) - 1), :, :]

        return pred_all.transpose(1, 2).contiguous()

djoad001 commented 3 years ago

Hi, I may have the same problem. I am looking for a way to predict new, unseen trajectory steps, but it seems that something is missing.

Briefly, given x1, x2, ..., xT with estimated edges, I would like to predict xT+1, ..., xT+N.

Thanks a lot for helping.

paulinesert commented 2 years ago

Hi @djoad001, how did you end up doing this? I also don't see how to predict successive timesteps iteratively when I don't have access to the ground truth (and thus cannot do teacher forcing every N timesteps). So I guess the solution is to modify the code to account for test mode and use the last mu (or draw a sample from the distribution) as the input for predicting the next step, repeating this until we get the desired number of timesteps?
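A minimal sketch of the loop I have in mind, not tested against this repo. The names `rollout` and `toy_step` are hypothetical and only for illustration; with the MLPDecoder, `step_fn` would presumably wrap `single_step_forward` with fixed `rel_rec`, `rel_send` and the inferred edge types, starting from the last observed step. The toy step function below is just a constant-velocity stand-in so the loop can be run without a trained model.

```python
from typing import Callable, List, TypeVar

State = TypeVar("State")


def rollout(step_fn: Callable[[State], State], last_state: State,
            num_steps: int) -> List[State]:
    """Apply step_fn repeatedly, feeding each prediction back in as input.

    No ground truth is used after last_state, i.e. there is no teacher forcing.
    """
    preds = []
    current = last_state
    for _ in range(num_steps):
        # The predicted mean (mu) becomes the input for the next step.
        current = step_fn(current)
        preds.append(current)
    return preds


def toy_step(state):
    # Hypothetical stand-in for the decoder: constant-velocity motion
    # on a (position, velocity) pair.
    pos, vel = state
    return (pos + vel, vel)


future = rollout(toy_step, (0.0, 1.5), num_steps=3)
print(future)  # [(1.5, 1.5), (3.0, 1.5), (4.5, 1.5)]
```

Sampling instead of using mu would only change what `step_fn` returns, so the loop itself would stay the same.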

Thanks for your help.