explosion / spaCy

💫 Industrial-strength Natural Language Processing (NLP) in Python
https://spacy.io
MIT License

TRF Alignment wrong? #7032

Closed maxtrem closed 3 years ago

maxtrem commented 3 years ago

How to reproduce the behaviour

I'm trying to extract trf vectors for certain spans in spaCy and have encountered several IndexErrors along the way. Here is one example:

import spacy
nlp = spacy.load('en_core_web_trf')
spacy.__version__ 
# '3.0.1'

text = "Almost daily something is reported which feeds this Catholic hope in England : statistics of the increasing numbers of converts and Irish Catholic immigrants ; news of a Protestant minister in Leamington who has offered to allow a Catholic priest to preach from his pulpit ; a report that a Catholic nun had been requested to teach in a non-Catholic secondary school during the sickness of one of its masters ; the startling statement in a respectable periodical that `` Catholics , if the present system is still in operation , will constitute almost one-third of the House of Lords in the next generation '' ; a report that 200 Protestant clergymen and laity attended a votive Mass offered for Christian unity at a Catholic church in Slough during the Church Unity Octave ."
doc = nlp(text)

doc[130]
# church

doc._.trf_data.align[130]
#Ragged(data=array([[174]], dtype=int32), lengths=array([1], dtype=int32), data_shape=(-1, 1), cumsums=None)

As I understand it, the last output (174) should be the alignment index into the wordpieces and trf vectors. However, indexing with it raises an IndexError, as both the wordpieces and the vectors are significantly shorter:

len(doc._.trf_data.wordpieces.strings[0])
#136

doc._.trf_data.tensors[0].shape
#(2, 136, 768)

Is this a bug or intentional? If it's the latter, how can I extract the correct alignment?

Thanks!


Here are the full wordpieces and alignments:

WordpieceBatch(strings=[['<s>', 'Almost', 'Ġdaily', 'Ġsomething', 'Ġis', 'Ġreported', 'Ġwhich', 'Ġfeeds', 'Ġthis', 'ĠCatholic', 'Ġhope', 'Ġin', 'ĠEngland', 'Ġ:', 'Ġstatistics', 'Ġof', 'Ġthe', 'Ġincreasing', 'Ġnumbers', 'Ġof', 'Ġconverts', 'Ġand', 'ĠIrish', 'ĠCatholic', 'Ġimmigrants', 'Ġ;', 'Ġnews', 'Ġof', 'Ġa', 'ĠProtestant', 'Ġminister', 'Ġin', 'ĠLe', 'aming', 'ton', 'Ġwho', 'Ġhas', 'Ġoffered', 'Ġto', 'Ġallow', 'Ġa', 'ĠCatholic', 'Ġpriest', 'Ġto', 'Ġpreach', 'Ġfrom', 'Ġhis', 'Ġpul', 'pit', 'Ġ;', 'Ġa', 'Ġreport', 'Ġthat', 'Ġa', 'ĠCatholic', 'Ġnun', 'Ġhad', 'Ġbeen', 'Ġrequested', 'Ġto', 'Ġteach', 'Ġin', 'Ġa', 'Ġnon', '-', 'Catholic', 'Ġsecondary', 'Ġschool', 'Ġduring', 'Ġthe', 'Ġsickness', 'Ġof', 'Ġone', 'Ġof', 'Ġits', 'Ġmasters', 'Ġ;', 'Ġthe', 'Ġstartling', 'Ġstatement', 'Ġin', 'Ġa', 'Ġrespectable', 'Ġperiod', 'ical', 'Ġthat', 'Ġ``', 'ĠCatholics', 'Ġ,', 'Ġif', 'Ġthe', 'Ġpresent', 'Ġsystem', 'Ġis', 'Ġstill', 'Ġin', 'Ġoperation', 'Ġ,', 'Ġwill', 'Ġconstitute', 'Ġalmost', 'Ġone', '-', 'third', 'Ġof', 'Ġthe', 'ĠHouse', 'Ġof', 'ĠLords', 'Ġin', 'Ġthe', 'Ġnext', 'Ġgeneration', "Ġ''", 'Ġ;', 'Ġa', 'Ġreport', 'Ġthat', 'Ġ200', 'ĠProtestant', 'Ġclergy', 'men', 'Ġand', 'Ġla', 'ity', 'Ġattended', 'Ġa', 'Ġvot', 'ive', 'ĠMass', 'Ġoffered', 'Ġfor', 'ĠChristian', 'Ġunity', 'Ġat', '</s>'], ['<s>', 'almost', 'Ġone', '-', 'third', 'Ġof', 'Ġthe', 'ĠHouse', 'Ġof', 'ĠLords', 'Ġin', 'Ġthe', 'Ġnext', 'Ġgeneration', "Ġ''", 'Ġ;', 'Ġa', 'Ġreport', 'Ġthat', 'Ġ200', 'ĠProtestant', 'Ġclergy', 'men', 'Ġand', 'Ġla', 'ity', 'Ġattended', 'Ġa', 'Ġvot', 'ive', 'ĠMass', 'Ġoffered', 'Ġfor', 'ĠChristian', 'Ġunity', 'Ġat', 'Ġa', 'ĠCatholic', 'Ġchurch', 'Ġin', 'ĠSl', 'ough', 'Ġduring', 'Ġthe', 'ĠChurch', 'ĠUnity', 'ĠOct', 'ave', 'Ġ.', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
'<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']], input_ids=array([[    0, 32136,  1230,   402,    16,   431,    61, 17456,    42,
         4019,  1034,    11,  1156,  4832,  6732,     9,     5,  2284,
         1530,     9, 33894,     8,  3445,  4019,  4175, 25606,   340,
            9,    10, 33478,  1269,    11,  1063,  9708,  1054,    54,
           34,  1661,     7,  1157,    10,  4019, 13174,     7, 31055,
           31,    39, 25578, 17291, 25606,    10,   266,    14,    10,
         4019, 20172,    56,    57,  5372,     7,  6396,    11,    10,
          786,    12, 42799,  5929,   334,   148,     5, 25231,     9,
           65,     9,    63, 22337, 25606,     5, 26556,   445,    11,
           10, 25031,   675,  3569,    14, 45518, 22509,  2156,   114,
            5,  1455,   467,    16,   202,    11,  2513,  2156,    40,
        14409,   818,    65,    12, 12347,     9,     5,   446,     9,
        26608,    11,     5,   220,  2706, 12801, 25606,    10,   266,
           14,  1878, 33478, 21064,  2262,     8,   897,  1571,  2922,
           10, 19314,  2088,  5370,  1661,    13,  2412,  8618,    23,
            2],
       [    0, 26949,    65,    12, 12347,     9,     5,   446,     9,
        26608,    11,     5,   220,  2706, 12801, 25606,    10,   266,
           14,  1878, 33478, 21064,  2262,     8,   897,  1571,  2922,
           10, 19314,  2088,  5370,  1661,    13,  2412,  8618,    23,
           10,  4019,  2352,    11,  4424,  4894,   148,     5,  2197,
        19573,  1700,  4097,   479,     2,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1]]), attention_mask=array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0]]), lengths=[136, 50], token_type_ids=None)

Ragged(data=array([[  1],
       [  2],
       [  3],
       [  4],
       [  5],
       [  6],
       [  7],
       [  8],
       [  9],
       [ 10],
       [ 11],
       [ 12],
       [ 13],
       [ 14],
       [ 15],
       [ 16],
       [ 17],
       [ 18],
       [ 19],
       [ 20],
       [ 21],
       [ 22],
       [ 23],
       [ 24],
       [ 25],
       [ 26],
       [ 27],
       [ 28],
       [ 29],
       [ 30],
       [ 31],
       [ 32],
       [ 33],
       [ 34],
       [ 35],
       [ 36],
       [ 37],
       [ 38],
       [ 39],
       [ 40],
       [ 41],
       [ 42],
       [ 43],
       [ 44],
       [ 45],
       [ 46],
       [ 47],
       [ 48],
       [ 49],
       [ 50],
       [ 51],
       [ 52],
       [ 53],
       [ 54],
       [ 55],
       [ 56],
       [ 57],
       [ 58],
       [ 59],
       [ 60],
       [ 61],
       [ 62],
       [ 63],
       [ 64],
       [ 65],
       [ 66],
       [ 67],
       [ 68],
       [ 69],
       [ 70],
       [ 71],
       [ 72],
       [ 73],
       [ 74],
       [ 75],
       [ 76],
       [ 77],
       [ 78],
       [ 79],
       [ 80],
       [ 81],
       [ 82],
       [ 83],
       [ 84],
       [ 85],
       [ 86],
       [ 86],
       [ 87],
       [ 88],
       [ 89],
       [ 90],
       [ 91],
       [ 92],
       [ 93],
       [ 94],
       [ 95],
       [ 96],
       [ 97],
       [ 98],
       [ 99],
       [100],
       [137],
       [101],
       [138],
       [102],
       [139],
       [103],
       [140],
       [104],
       [141],
       [105],
       [142],
       [106],
       [143],
       [107],
       [144],
       [108],
       [145],
       [109],
       [146],
       [110],
       [147],
       [111],
       [148],
       [112],
       [149],
       [113],
       [150],
       [114],
       [151],
       [115],
       [152],
       [116],
       [153],
       [117],
       [154],
       [118],
       [155],
       [119],
       [156],
       [120],
       [121],
       [157],
       [158],
       [122],
       [159],
       [123],
       [124],
       [160],
       [161],
       [125],
       [162],
       [126],
       [163],
       [127],
       [128],
       [164],
       [165],
       [129],
       [166],
       [130],
       [167],
       [131],
       [168],
       [132],
       [169],
       [133],
       [170],
       [134],
       [171],
       [172],
       [173],
       [174],
       [175],
       [176],
       [177],
       [178],
       [179],
       [180],
       [181],
       [182],
       [183],
       [184]], dtype=int32), lengths=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 4, 2, 4, 2, 2, 4, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1,
       2, 1, 1, 1, 1, 2, 1], dtype=int32), data_shape=(-1,), cumsums=array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  34,  35,  36,  37,  38,  39,  40,  41,
        42,  43,  44,  45,  46,  48,  49,  50,  51,  52,  53,  54,  55,
        56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,
        69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,
        82,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
        96,  97,  98,  99, 100, 102, 104, 106, 108, 110, 112, 114, 116,
       118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138, 140, 144,
       146, 150, 152, 154, 158, 160, 162, 164, 166, 168, 170, 171, 172,
       173, 174, 176, 177, 178, 179, 180, 182, 183]))


honnibal commented 3 years ago

Hi @maxtrem,

In order to handle long inputs, we extract potentially overlapping spans from the Doc and pass those into the transformer. So let's say you have a doc of 20 tokens, a window of 15 and a stride of 10. This will cut the doc into a slice of 15 tokens and a slice of 10. If the transformer has a width of 768, we'll get a tensor of shape (2, 15, 768) (after padding). The alignment points into a table that refers to a 2D reshaped version of that tensor, so you'll see indices up to 30.
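The slicing described above can be sketched in a few lines of plain Python (a minimal illustration, not spaCy's actual implementation; `window_starts` is a hypothetical helper):

```python
def window_starts(n_tokens, window, stride):
    """Start offsets of the (potentially overlapping) windows cut from a doc."""
    starts = [0]
    # Keep adding windows until the last one reaches the end of the doc.
    while starts[-1] + window < n_tokens:
        starts.append(starts[-1] + stride)
    return starts

# A doc of 20 tokens with window=15 and stride=10 yields two slices,
# doc[0:15] and doc[10:20], which overlap on tokens 10..14:
print(window_starts(20, 15, 10))  # -> [0, 10]
```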

So what you'll need to do is either flatten the nested doc._.trf_data.wordpieces.strings list so that you can index into it directly, or map the flat index back to two dimensions like this:

strings = doc._.trf_data.wordpieces.strings
seq_size = len(strings[-1])    # all windows are padded to the same length
batch = index // seq_size      # which window the wordpiece falls into
item = index % seq_size        # position within that window
string = strings[batch][item]
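Applied to the example from the issue, the arithmetic works out as follows (a sketch in plain Python, no model required; `unflatten_index` is just an illustrative name):

```python
def unflatten_index(index, seq_size):
    """Map a flat alignment index into (window, position-within-window)."""
    return index // seq_size, index % seq_size

# Token doc[130] ("church") aligned to the flat index 174, and each padded
# window in this doc holds 136 wordpieces:
batch, item = unflatten_index(174, 136)
print(batch, item)  # -> 1 38, i.e. wordpiece 38 of the second window
```

Wordpiece 38 of the second window is indeed 'Ġchurch' in the dump above, so the alignment is consistent once the index is interpreted against the flattened (windows × seq_len) layout.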
maxtrem commented 3 years ago

Ah okay, it makes perfect sense now. I was actually wondering why there is a 2 in the first dimension of the tensor shape (2, 136, 768).

I think I'm well served with this explanation, thanks a lot!

github-actions[bot] commented 3 years ago

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.