Wangpeiyi9979 / ESD

Code for NAACL2022 Long Paper "An Enhanced Span-based Decomposition Method for Few-Shot Sequence Labeling"

About the function sequence_mask(sequence_length) #6

Closed. MingHong1019 closed this issue 2 years ago.

MingHong1019 commented 2 years ago

https://github.com/Wangpeiyi9979/ESD/blob/d2c89249143c4703527610b89accc26bb1cd9157/model/utils.py#L7

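When handling sentences with different numbers of spans, you pad the sentences that have fewer spans and then record which spans are padding. For example, if the current sentence has 12 spans and the largest span count in the batch is 14, its 12 spans are extended to 14, and the last 2 are 0. However, when this function marks the positions of those 0s, it compares with ">" (line 16 of utils.py). Shouldn't that be ">=" instead? Because

seq_range_expand = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
seq_length_expand = [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]

with ">", only index 13 is marked, so one padded position (index 12) is missed.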
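To make the off-by-one concrete, here is a minimal pure-Python sketch of a length-based padding mask. It is not the repository's PyTorch code: the helper name padding_mask and its inclusive switch are hypothetical. inclusive=False mimics the ">" comparison and inclusive=True the ">=" fix, using the 12-of-14 example from the issue.

```python
def padding_mask(span_nums, max_len=None, inclusive=True):
    """Hypothetical helper: True where a span slot is padding.

    span_nums -- real number of spans in each sentence
    inclusive -- True compares with '>=' (the fix agreed on in this issue),
                 False compares with '>' (the comparison reported at utils.py line 16)
    """
    if max_len is None:
        max_len = max(span_nums)
    mask = []
    for length in span_nums:
        # idx plays the role of seq_range_expand, length of seq_length_expand
        if inclusive:
            row = [idx >= length for idx in range(max_len)]
        else:
            row = [idx > length for idx in range(max_len)]
        mask.append(row)
    return mask


span_nums = [12, 14]  # the 12-span sentence from the issue, padded to 14
buggy = padding_mask(span_nums, inclusive=False)
fixed = padding_mask(span_nums, inclusive=True)
print(sum(buggy[0]), sum(fixed[0]))  # 1 2
```

With ">", the first sentence gets only one padded slot flagged (index 13); with ">=" both padded slots, indices 12 and 13, are flagged. The fully populated second sentence gets no padding in either case.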

Wangpeiyi9979 commented 2 years ago

Hello, thanks for pointing this out. It should indeed use >= here.

MingHong1019 commented 2 years ago

Could the function call at https://github.com/Wangpeiyi9979/ESD/blob/master/model/ESD.py#L106 be replaced with support_span_mask = (1-support_is_padding).bool().view(len(support_span_nums),-1) instead?
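The suggested replacement builds the mask directly from per-span padding flags instead of from span counts. Below is a minimal pure-Python sketch of why the two can agree, assuming (this is not confirmed by the thread) that support_is_padding is a flat sequence of 0/1 flags laid out sentence by sentence, with every sentence padded to the same width. The helper names are hypothetical, and the flag-based mask is compared against an idx < length mask, i.e. True on real spans, which is the complement of the corrected ">=" padding test.

```python
def mask_from_padding_flags(is_padding, num_sentences):
    """Pure-Python analogue of (1 - is_padding).bool().view(num_sentences, -1).

    Returns True for real spans and False for padded slots.
    """
    width = len(is_padding) // num_sentences  # padded span count per sentence
    flat = [bool(1 - flag) for flag in is_padding]
    return [flat[i * width:(i + 1) * width] for i in range(num_sentences)]


def mask_from_lengths(span_nums, max_len):
    """True for real spans: idx < length, the complement of the '>=' padding test."""
    return [[idx < length for idx in range(max_len)] for length in span_nums]


span_nums = [12, 14]
# Hypothetical flags: the first sentence has 2 padded slots, the second none.
is_padding = [0] * 12 + [1] * 2 + [0] * 14
print(mask_from_padding_flags(is_padding, len(span_nums)) == mask_from_lengths(span_nums, 14))  # True
```

The flag-based version relies on every sentence being padded to the same width, which is also what .view(len(support_span_nums), -1) requires in order to reshape the flat flags into one row per sentence.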

Wangpeiyi9979 commented 2 years ago

Yes, that should work.