helloacl / DST-DCPDS

code for task-oriented dialogue state tracking

Bug in the code that concatenates domain-slot names with slot values #5

Closed: shx5625958 closed this issue 2 years ago

shx5625958 commented 2 years ago
        for s_id, label_id in enumerate(label_ids):
            for s_label in label_id.tolist():
                value_len = len([x for x in s_label if x != 0])
                # Bug: `s` is not this loop's variable; this should be slot_ids[s_id]
                slot_len = len([x for x in slot_ids[s].tolist() if x != 0])
                combine_len = value_len + slot_len
                if max_sv_len < combine_len:
                    max_sv_len = combine_len
        mix_label_ids = []
        for s_id, label_id in enumerate(label_ids):
            tmp_label_id = torch.zeros([label_id.shape[0], max_sv_len], dtype=torch.long).to(self.device)
            for elemnet_id, s_label in enumerate(label_id.tolist()):
                value_element = [x for x in s_label if x != 0]
                # Bug: same stale `s` as above; this should be slot_ids[s_id]
                slot_element = [x for x in slot_ids[s].tolist() if x != 0]
                combine_element = slot_element[0:-1] + value_element
                tmp_label_id[elemnet_id][0:len(combine_element)] = torch.tensor(combine_element, dtype=torch.long).to(self.device)
            mix_label_ids.append(tmp_label_id)

The `s` in `slot_ids[s]` here is always 34, which is a problem. I think `s` should be `s_id`.
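To make the proposed fix concrete, here is a minimal, torch-free sketch of the two loops with `slot_ids[s_id]` in place of `slot_ids[s]`. It assumes plain nested Python lists instead of tensors, 0 as the padding id, and that `label_ids[i]` holds the candidate values of slot i; the function name `concat_slot_value_ids` and the toy token ids are hypothetical. Like the snippet above, it drops the slot's last non-padding token before appending the value.

```python
from typing import List


def concat_slot_value_ids(slot_ids: List[List[int]],
                          label_ids: List[List[List[int]]]) -> List[List[List[int]]]:
    """Prefix every candidate value of slot i with slot i's own tokens, then zero-pad.

    slot_ids[i]  -- token ids for slot i (0 = padding)
    label_ids[i] -- list of candidate-value token-id lists for slot i (0 = padding)
    """
    # Pass 1: longest slot+value combination over all slots, as in the issue's snippet.
    max_sv_len = 0
    for s_id, label_id in enumerate(label_ids):
        slot_len = len([x for x in slot_ids[s_id] if x != 0])  # index with s_id, not a stale s
        for s_label in label_id:
            value_len = len([x for x in s_label if x != 0])
            max_sv_len = max(max_sv_len, slot_len + value_len)

    # Pass 2: slot tokens minus the last one, followed by value tokens, right-padded with 0.
    mix_label_ids = []
    for s_id, label_id in enumerate(label_ids):
        slot_element = [x for x in slot_ids[s_id] if x != 0]
        rows = []
        for s_label in label_id:
            value_element = [x for x in s_label if x != 0]
            combine_element = slot_element[:-1] + value_element
            rows.append(combine_element + [0] * (max_sv_len - len(combine_element)))
        mix_label_ids.append(rows)
    return mix_label_ids


# Toy token ids: two slots, the first with two candidate values.
slot_ids = [[1, 7, 2, 0], [1, 8, 9, 2]]
label_ids = [[[1, 50, 2], [1, 51, 52, 2]], [[1, 60, 2]]]
print(concat_slot_value_ids(slot_ids, label_ids))
# → [[[1, 7, 1, 50, 2, 0, 0], [1, 7, 1, 51, 52, 2, 0]], [[1, 8, 9, 1, 60, 2, 0]]]
```

Each slot's values now start with that slot's own tokens. Wrapping each slot's `rows` in `torch.tensor(rows, dtype=torch.long)` would give a tensor with the same padded layout as `tmp_label_id`.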

helloacl commented 2 years ago

@shx5625958 Thank you for your comment. The bug has already been fixed. The buggy version effectively ignored the slot-name information, since every value was concatenated with the same slot's tokens.
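The enclosing code is not shown, so why `s` is stuck at 34 is an assumption here. A common way this happens in Python is that a `for` loop variable stays bound to its last value after the loop ends, and 34 would be the last index of a loop over 35 slots. The hypothetical sketch below (`slot_prefixes` and the toy ids are illustrative, not from the repository) shows the effect the maintainer describes: with the stale index, every slot gets the last slot's tokens, so the slot names no longer distinguish anything.

```python
from typing import List


def slot_prefixes(slot_ids: List[List[int]], use_stale_index: bool) -> List[List[int]]:
    """Return the slot-token prefix that each slot's values would be joined with."""
    # Hypothetical earlier loop over all slots. Python keeps `s` bound after the
    # loop finishes, so here it ends up as len(slot_ids) - 1.
    for s in range(len(slot_ids)):
        pass

    prefixes = []
    for s_id in range(len(slot_ids)):
        idx = s if use_stale_index else s_id  # buggy: stale s, fixed: s_id
        prefixes.append([x for x in slot_ids[idx] if x != 0][:-1])
    return prefixes


slot_ids = [[1, 7, 2], [1, 8, 2], [1, 9, 2]]
print(slot_prefixes(slot_ids, use_stale_index=True))   # → [[1, 9], [1, 9], [1, 9]]
print(slot_prefixes(slot_ids, use_stale_index=False))  # → [[1, 7], [1, 8], [1, 9]]
```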