# Build, for every slot, a tensor of label rows in which each row is the
# slot's non-zero token ids (minus the last one) followed by the value's
# non-zero token ids, right-filled with 0 up to a shared width.
#
# Reads:  label_ids, slot_ids, max_sv_len (must already be initialised),
#         self.device
# Writes: max_sv_len (raised to the widest slot+value length seen),
#         mix_label_ids (one LongTensor per entry of label_ids)
# NOTE(review): assumes slot_ids[i] holds the token ids of the slot that
# label_ids[i] belongs to, and that 0 is the padding id -- confirm.

# Pass 1: find the widest slot+value combination.
for s_id, label_id in enumerate(label_ids):
    # Index the slot with this loop's own index. Using an unrelated `s`
    # left over from elsewhere made every slot use the same slot tokens.
    # The slot tokens do not depend on the row, so measure them once.
    slot_len = len([x for x in slot_ids[s_id].tolist() if x != 0])
    for s_label in label_id.tolist():
        value_len = len([x for x in s_label if x != 0])
        combine_len = value_len + slot_len
        if max_sv_len < combine_len:
            max_sv_len = combine_len

# Pass 2: materialise the combined rows, zero-filled to max_sv_len.
mix_label_ids = []
for s_id, label_id in enumerate(label_ids):
    tmp_label_id = torch.zeros([label_id.shape[0], max_sv_len], dtype=torch.long).to(self.device)
    slot_element = [x for x in slot_ids[s_id].tolist() if x != 0]
    # The slot's last non-zero token is dropped before the value is appended
    # (presumably a trailing separator/end token -- TODO confirm).
    slot_prefix = slot_element[0:-1]
    for element_id, s_label in enumerate(label_id.tolist()):
        value_element = [x for x in s_label if x != 0]
        combine_element = slot_prefix + value_element
        tmp_label_id[element_id][0:len(combine_element)] = torch.tensor(combine_element, dtype=torch.long).to(self.device)
    mix_label_ids.append(tmp_label_id)
这里的 slot_ids[s] 中的 s 一直是 34，这里有问题；我觉得这里的 s 应该是 s_id（外层 enumerate 循环的下标）。