urbanmobility / CSLSL

PyTorch implementation of the paper "Human Mobility Prediction with Causal and Spatial-constrained Multi-task Network"
MIT License

Maybe something wrong in model.py #3

Open K-King6 opened 2 weeks ago

K-King6 commented 2 weeks ago

        cur_t_rnn, hc_t = self.capturer_t(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:])
        if self.cat_contained:
            cur_c_rnn, hc_c = self.capturer_c(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t)
            cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_c)

        # 4) tower, t,c,l
        # CMTL
        hc_t, hc_c, hc_l = hc_t.squeeze(), hc_c.squeeze(), hc_l.squeeze()

        c_pred = self.fc_c(hc_c) 
        c_trans = self.label_trans_c(c_pred.clone())
        t_pred = self.fc_t(torch.cat((hc_t, c_trans), dim=-1)) 
        t_trans = self.label_trans_t(t_pred.clone())
        l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1))

You first calculate hc_t and use it to calculate hc_c, but in the towers you first calculate c_trans and use it to calculate t_pred. The two orderings seem inconsistent, which may make your results worse.

herozen97 commented 2 weeks ago

Thank you for pointing out this issue, which brought to our attention that the published version is the "what -> when -> where" variant. If you intend to use it as a baseline, we would appreciate it if you could revert it to the "when -> what -> where" version proposed in the paper. Thank you once again, and best wishes.
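
For readers attempting the revert, the tower wiring in the "when -> what -> where" order might look roughly like the sketch below. It is not the repository's code: the class name `WhenWhatWhereTowers`, all layer sizes, and the use of plain `nn.Linear` layers for `label_trans_t` and `label_trans_c` are assumptions made for illustration. Only the names `fc_t`, `fc_c`, `fc_l`, `label_trans_t`, `label_trans_c` and the `hc_*` inputs are taken from the snippet quoted in this issue. Note that changing the order also changes the input widths of `fc_t` and `fc_c`, so the layer definitions in `model.py` would likely need to be adjusted as well.

```python
import torch
import torch.nn as nn


class WhenWhatWhereTowers(nn.Module):
    """Hypothetical tower block wired as time (t) -> category (c) -> location (l)."""

    def __init__(self, hidden=32, n_t=24, n_c=10, n_l=100, trans=8):
        super().__init__()
        # All sizes here are illustrative assumptions, not the repository's values.
        # Time tower comes first, so it only sees its own hidden state.
        self.fc_t = nn.Linear(hidden, n_t)
        self.label_trans_t = nn.Linear(n_t, trans)
        # Category tower also consumes the transformed time prediction.
        self.fc_c = nn.Linear(hidden + trans, n_c)
        self.label_trans_c = nn.Linear(n_c, trans)
        # Location tower also consumes the transformed category prediction.
        self.fc_l = nn.Linear(hidden + trans, n_l)

    def forward(self, hc_t, hc_c, hc_l):
        t_pred = self.fc_t(hc_t)
        t_trans = self.label_trans_t(t_pred.clone())
        c_pred = self.fc_c(torch.cat((hc_c, t_trans), dim=-1))
        c_trans = self.label_trans_c(c_pred.clone())
        l_pred = self.fc_l(torch.cat((hc_l, c_trans), dim=-1))
        return t_pred, c_pred, l_pred


torch.manual_seed(0)
towers = WhenWhatWhereTowers()
# Stand-ins for the squeezed capturer outputs: batch of 4, hidden size 32.
hc_t, hc_c, hc_l = (torch.randn(4, 32) for _ in range(3))
t_pred, c_pred, l_pred = towers(hc_t, hc_c, hc_l)
print(tuple(t_pred.shape), tuple(c_pred.shape), tuple(l_pred.shape))
```

With this wiring the tower order matches the capturer order shown in the first comment (hc_t feeds hc_c, which feeds hc_l), so the time prediction conditions the category prediction rather than the other way round.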