octanove / shiba

PyTorch implementation and pre-trained Japanese model for CANINE, the efficient character-level transformer.

Masking: tensor shape mismatch #1

Closed stefan-it closed 2 years ago

stefan-it commented 3 years ago

Hi :hugs:

I'm currently trying to pre-train a new model on my own data (German, not Japanese). I followed the pre-processing steps and created the jsonl file.

When using rand_char as the masking strategy, after 733 steps (with a batch size of 11 on a single V100), the following error is thrown:

Traceback (most recent call last):                                      
  File "training/train.py", line 59, in <module>
    main()                   
  File "training/train.py", line 55, in main
    trainer.train(resume_from_checkpoint=checkpoint_dir)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1246, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 518, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 558, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/mnt/shiba/training/masking.py", line 285, in __call__
    input_ids, labels, masked_indices = random_mask(padded_batch['input_ids'],
  File "/mnt/shiba/training/masking.py", line 55, in random_mask
    indices_to_mask = torch.stack(indices_to_mask)
RuntimeError: stack expects each tensor to be equal size, but got [258] at entry 0 and [252] at entry 10

When using rand_span, it also throws an error message:

Traceback (most recent call last):
  File "training/train.py", line 59, in <module>
    main()
  File "training/train.py", line 55, in main
    trainer.train(resume_from_checkpoint=checkpoint_dir)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1246, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 518, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 558, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/mnt/shiba/training/masking.py", line 264, in __call__
    input_ids, labels, masked_indices = random_span_mask(padded_batch['input_ids'],
  File "/mnt/shiba/training/masking.py", line 193, in random_span_mask
    all_masked_indices = torch.stack(masked_indices_per_row)
RuntimeError: stack expects each tensor to be equal size, but got [258] at entry 0 and [250] at entry 10

This also happens after 733 steps. Do you have any explanation for this behavior? :thinking:

Many thanks!
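For context, the RuntimeError above is simply torch.stack refusing tensors of unequal length; a minimal standalone sketch (the sizes are made up to echo the message above):

import torch

# Two 1-D index tensors of different lengths, mimicking the sizes in the error message.
indices_to_mask = [
    torch.randperm(300)[:258],  # 258 indices selected for one sequence
    torch.randperm(300)[:252],  # only 252 selected for another sequence
]

torch.stack(indices_to_mask)
# RuntimeError: stack expects each tensor to be equal size, but got [258] at entry 0 and [252] at entry 1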

Mindful commented 3 years ago

Hey Stefan,

So both of these methods are crashing for the same reason: they can't mask the same number of characters for two different sequences in the same batch. I thought I had fixed this, but it may have gotten un-fixed when I rewrote some things recently. I can probably find time to fix this sometime in the next week or so, but I have a few things I want to clarify first:

Then, in regard to your issue in particular:

stefan-it commented 3 years ago

Hi @Mindful ,

thanks for your help :hugs:

So the input format that I use (as far as I understand the hints in the training sheet) is one sentence per line, such as:

Nr.
SI
Ausgabe ' -Lt
Freitag.
19. April 1918
SB.
Jahrgang
Erscheint «w ave» Wochentagen ««chmittag» zwischen 4 a.
7 Rvr.
cha«»t Er>editi»« ««d Stedattto» Hamburg.

I also used blank lines to mark document boundaries. I used my own sentence-splitting routine (it basically comes from NLTK), so I was not using the preprocess.sh script.
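For reference, a minimal sketch of that kind of preprocessing, assuming NLTK's German punkt model and made-up file names (this is not the repo's preprocess.sh):

import nltk

nltk.download('punkt')  # sentence tokenizer models, including German

# One sentence per line, blank line between documents.
with open('ocr_documents.txt', encoding='utf-8') as f_in, \
        open('pretraining_corpus.txt', 'w', encoding='utf-8') as f_out:
    for document in f_in.read().split('\n\n'):
        for sentence in nltk.sent_tokenize(document, language='german'):
            f_out.write(sentence.strip() + '\n')
        f_out.write('\n')  # document boundary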

And yeah, the input text comes from an OCR pipeline 😅. I then used the to_examples.py script to create the jsonl file; it looks like this:

{"input_ids": [57344, 100, 46, 105, 114, 99, 104, 32, 100, 97, 187, 32, 66, 114, 97, 110, 110, 116, 119, 101, 105, 110, 32, 77, 111, 110, 111, 112, 111, 108, 44, 32, 100, 101, 110, 32, 196, 97, 102, 116, 101, 101, 122, 111, 108, 108, 44, 32, 100, 105, 101, 32, 108, 108, 105, 110, 39, 97, 116, 122, 115, 116, 101, 117, 101, 114, 44, 32, 103, 101, 110, 44, 101, 105, 117, 104, 105, 110, 32, 119, 105, 101, 32, 101, 116, 119, 97, 32, 97, 109, 116, 105, 32, 100, 105, 105, 114, 99, 104, 32, 100, 105, 101, 32, 80, 32, 187, 115, 116, 32, 107, 97, 114, 116, 101, 110, 32, 187, 32, 115, 116, 32, 101, 117, 101, 114, 32, 60, 100, 105, 101, 32, 80, 111, 105, 116, 107, 110, 114, 105, 101, 32, 115, 116, 101, 104, 116, 32, 98, 101, 105, 32, 46, 57345, 39, 101, 114, 32, 77, 46, 39, 115, 115, 101, 32, 109, 101, 104, 114, 32, 105, 110, 32, 71, 117, 110, 115, 116, 32, 97, 108, 39, 59, 32, 98, 101, 105, 32, 100, 101, 110, 32, 66, 99, 252, 98, 99, 110, 100, 101, 110, 41, 32, 115, 105, 99, 104, 32, 100, 105, 101, 32, 117, 110, 116, 101, 114, 171, 187, 32, 83, 99, 104, 105, 99, 104, 116, 101, 110, 32, 98, 101, 115, 111, 110, 100, 101, 114, 187, 32, 115, 99, 104, 119, 101, 114, 32, 103, 101, 116, 114, 111, 102, 102, 101, 110, 32, 102, 252, 108, 108, 101, 110, 32, 119, 101, 114, 100, 101, 110, 46, 57345, 68, 105, 101, 32, 65, 117, 115, 115, 105, 99, 104, 116, 32, 97, 117, 105, 32, 100, 105, 101, 32, 68, 117, 114, 99, 104, 115, 101, 116, 122, 117, 110, 103, 32, 100, 101, 114, 32, 83, 116, 101, 117, 101, 114, 45, 32, 80, 108, 228, 110, 101, 32, 105, 110, 32, 100, 101, 114, 32, 118, 111, 114, 108, 105, 101, 103, 101, 110, 100, 101, 110, 32, 71, 101, 115, 116, 97, 108, 116, 32, 107, 97, 110, 110, 32, 100, 97, 110, 97, 99, 104, 32, 110, 105, 99, 104, 116, 32, 104, 111, 99, 104, 32, 97, 110, 103, 101, 115, 99, 104, 108, 97, 103, 101, 110, 32, 119, 101, 114, 100, 101, 110, 46, 57345, 69, 105, 110, 32, 110, 105, 99, 104, 116, 32, 103, 101, 114, 105, 110, 103, 101, 187, 32, 77, 97, 223, 32, 118, 111, 110, 32, 87, 97, 104, 114, 115, 99, 104, 101, 105, 110, 108, 105, 99, 104, 107, 101, 105, 116, 32, 115, 112, 114, 105, 99, 104, 116, 32, 100, 97, 105, 252, 114, 44, 32, 100, 97, 223, 32, 100, 101, 114, 32, 82, 101, 105, 99, 104, 115, 116, 97, 103, 32, 114, 97, 100, 105, 107, 97, 108, 171, 32, 65, 99, 110, 100, 101, 114, 117, 110, 103, 101, 110, 32, 100, 97, 114, 97, 110, 32, 118, 101, 114, 115, 117, 99, 104, 101, 110, 32, 119, 105, 114, 100, 46, 57345, 69, 187, 32, 100, 252, 114, 115, 116, 101, 110, 32, 171, 171, 32, 100, 116, 171, 32, 118, 111, 114, 103, 101, 106, 99, 104, 108, 97, 65, 116, 110, 101, 171, 32, 83, 116, 101, 117, 101, 114, 110, 32, 72, 101, 105, 107, 101, 32, 76, 228, 109, 187, 57345, 57345, 71, 45, 119, 32, 187, 57345, 71, 101, 110, 101, 114, 97, 108, 45, 171, 44, 44, 44, 101, 108, 103, 101, 114, 32, 102, 252, 114, 32, 72, 97, 109, 98, 187, 114, 171, 45, 187, 114, 119, 171, 97, 32, 187, 111, 171, 32, 70, 114, 101, 116, 116, 97, 103, 46, 57345, 100, 101, 187, 32, 73, 86, 32, 171, 118, 114, 105, 108, 32, 187, 187, 73, 187, 57345, 101, 110, 116, 98, 114, 101, 110, 110, 101, 110, 46, 57345, 77, 97, 110, 99, 104, 101, 32, 100, 105, 114, 101, 107, 116, 101, 32, 83, 116, 101, 117, 101, 114, 32, 119, 105, 114, 100, 32, 118, 111, 110, 32, 100, 101, 110, 32, 80, 97, 114, 116, 101, 105, 101, 110, 32, 100, 101, 114, 32, 82, 101, 105, 99, 104, 83, 116, 97, 103, 83, 45, 77, 101, 104, 114, 104, 101, 114, 116, 32, 105, 110, 32, 100, 101, 109, 32, 83, 116, 101, 117, 101, 114, 98, 117, 107, 101, 116, 116, 32, 100, 101, 114, 109, 
105, 103, 116, 32, 87, 101, 114, 100, 101, 114, 44, 32, 109, 97, 110, 99, 104, 101, 114, 32, 86, 101, 114, 98, 114, 97, 117, 99, 104, 32, 119, 105, 114, 100, 32, 105, 104, 110, 101, 110, 32, 105, 110, 100, 101, 109, 32, 111, 101, 112, 108, 97, 110, 108, 101, 110, 32, 77, 97, 223, 101, 32, 122, 117, 32, 117, 110, 114, 101, 99, 104, 116, 32, 104, 101, 114, 97, 110, 103, 101, 122, 111, 103, 101, 110, 32, 101, 114, 115, 99, 104, 101, 105, 110, 101, 110, 46, 57345, 68, 97, 223, 32, 100, 105, 101, 32, 68, 105, 115, 107, 117, 39, 102, 105, 111, 110, 32, 100, 101, 114, 32, 69, 116, 101, 117, 101, 114, 118, 111, 114, 108, 97, 103, 101, 110, 32, 102, 105, 99, 104, 32, 105, 109, 32, 90, 101, 105, 99, 104, 101, 110, 32, 100, 101, 83, 32, 66, 117, 114, 103, 102, 114, 105, 101, 100, 101, 110, 115, 32, 118, 111, 108, 108, 122, 105, 101, 104, 101, 110, 32, 119, 105, 114, 100, 44, 32, 105, 115, 116, 32, 97, 117, 115, 103, 101, 115, 99, 104, 108, 111, 102, 102, 101, 110, 46, 57345, 76, 117, 99, 104, 32, 97, 117, 115, 32, 101, 105, 110, 101, 32, 115, 101, 104, 114, 32, 114, 97, 115, 99, 104, 101, 32, 86, 101, 114, 97, 98, 115, 99, 104, 105, 101, 100, 117, 110, 103, 32, 100, 101, 114, 32, 83, 116, 101, 117, 101, 114, 32, 112, 108, 228, 110, 101, 32, 105, 115, 116, 32, 117, 110, 116, 101, 114, 32, 100, 105, 101, 115, 101, 110, 32, 85, 109, 115, 116, 228, 110, 100, 101, 110, 32, 107, 97, 117, 109, 32, 122, 117, 32, 114, 101, 99, 104, 110, 101, 110, 46, 57345, 68, 101, 114, 32, 82, 101, 105, 99, 104, 115, 116, 97, 103, 32, 119, 105, 114, 100, 32, 115, 105, 99, 104, 32, 119, 105, 101, 100, 101, 114, 32, 105, 110, 32, 100, 101, 114, 32, 82, 111, 108, 108, 101, 32, 100, 101, 115, 32, 77, 105, 116, 97, 114, 98, 101, 105, 116, 101, 114, 115, 32, 100, 99, 83, 32, 82, 101, 105, 99, 104, 45, 39, 99, 104, 97, 116, 122, 97, 109, 116, 101, 83, 32, 119, 32, 100, 101, 109, 32, 69, 114, 110, 110, 101, 32, 118, 101, 114, 115, 117, 99, 104, 101, 110, 187, 32, 100, 97, 223, 32, 101, 114, 32, 110, 101, 117, 101, 32, 83, 116, 101, 117, 101, 114, 110, 32, 98, 114, 105, 110, 103, 101, 110, 32, 119, 105, 114, 100, 46, 57345, 73, 110, 32, 100, 101, 114, 32, 66, 101, 103, 114, 101, 110, 122, 117, 110, 103, 44, 32, 105, 110, 32, 100, 101, 114, 32, 100, 97, 83, 32, 83, 116, 101, 117, 101, 114, 98, 117, 107, 101, 116, 116, 32, 100, 101, 109, 32, 82, 101, 105, 99, 104, 115, 116, 97, 103, 32, 112, 114, 228, 39, 101, 110, 116, 105, 101, 114, 116, 32, 119, 105, 114, 100, 44, 32, 119, 105, 114, 100, 32, 101, 83, 32, 105, 104, 110, 32, 97, 108, 32, 111, 32, 110, 105, 99, 107, 116, 32, 119, 105, 101, 100, 101, 114, 32, 118, 101, 114, 108, 97, 115, 115, 101, 110, 46, 57345, 68, 105, 101, 32, 103, 114, 246, 223, 116, 101, 110, 32, 66, 101, 100, 101, 110, 107, 101, 110, 32, 100, 252, 114, 115, 116, 101, 32, 100, 105, 101, 32, 85, 109, 115, 97, 116, 122, 115, 116, 101, 117, 101, 114, 32, 97, 117, 115, 108, 246, 115, 101, 110, 46, 57345, 86, 111, 110, 32, 100, 101, 110, 32, 97, 110, 100, 101, 114, 101, 110, 32, 83, 116, 101, 117, 101, 114, 110, 32, 97, 117, 102, 32, 100, 101, 110, 32, 86, 101, 114, 98, 114, 97, 117, 99, 104, 32, 104, 97, 116, 32, 118, 105, 101, 108, 108, 101, 105, 99, 104, 116, 32, 100, 105, 101, 32, 69, 114, 104, 246, 104, 117, 110, 103, 32, 100, 101, 115, 32, 83, 97, 102, 102, 101, 101, 122, 111, 108, 108, 101, 83, 32, 117, 110, 100, 32, 101, 116, 119, 97, 32, 97, 117, 99, 104, 32, 100, 105, 101, 32, 196, 105, 101, 114, 115, 116, 101, 117, 101, 114, 44, 32, 116, 114, 111, 116, 122, 100, 101, 109, 32, 115, 105, 101, 32, 105, 110, 
32, 83, 101, 109, 32, 103, 101, 112, 108, 97, 110, 116, 101, 110, 32, 85, 109, 102, 97, 110, 103, 32, 110, 105, 99, 104, 116, 32, 117, 110, 109, 228, 223, 105, 103, 32, 105, 115, 116, 44, 32, 97, 108, 171, 32, 103, 101, 32, 102, 228, 104, 114, 100, 101, 116, 32, 122, 117, 32, 103, 101, 108, 116, 101, 110, 46, 57345, 90, 117, 32, 100, 101, 110, 32, 110, 101, 117, 101, 110, 32, 83, 116, 101, 117, 101, 114, 103, 101, 115, 101, 104, 101, 110, 32, 118, 101, 114, 108, 97, 117, 116, 101, 116, 32, 105, 109, 32, 82, 101, 105, 99, 104, 115, 32, 116, 97, 103, 32, 115, 99, 104, 111, 110, 32, 106, 101, 108, 63, 116, 46, 57345, 100, 97, 223, 32, 115, 105, 99, 104, 32, 107, 101, 114, 110, 101, 32, 77, 101, 104, 114, 104, 101, 105, 116, 32, 102, 252, 114, 32, 101, 105, 110, 101, 32, 97, 98, 101, 114, 109, 97, 108, 105, 103, 101, 32, 69, 114, 104, 246, 104, 117, 110, 103, 32, 100, 101, 114, 32, 80, 111, 115, 116, 45, 117, 110, 100, 32, 84, 101, 108, 101, 103, 114, 97, 112, 104, 101, 110, 32, 103, 101, 98, 252, 104, 114, 101, 110, 44, 32, 115, 111, 119, 105, 101, 32, 108, 252, 114, 32, 101, 105, 110, 101, 32, 66, 101, 115, 116, 101, 117, 101, 114, 117, 110, 103, 32, 118, 111, 110, 32, 82, 97, 102, 102, 101, 171, 44, 32, 84, 101, 101, 32, 117, 110, 100, 32, 97, 108, 107, 111, 104, 111, 108, 102, 114, 101, 105, 101, 110, 32, 71, 101, 116, 114, 228, 110, 107, 101, 110, 32, 102, 105, 110, 100, 101, 110, 32, 119, 252, 114, 100, 101, 46]}

But I will do some debugging now to find the potential sentence/input that is causing problems :)

Mindful commented 3 years ago

Hey Stefan,

So looking at your data, I think I might know what's wrong. The example input ids you gave look fine, but the example input data you gave has a lot of really short lines. The masking methods exclude padding from the possible characters to mask, but they also exclude [SEP] characters. to_examples.py puts a [SEP] token in between every sentence (every line of input), so if you have a batch that includes an example made up of a large number of these really short sentences, the number of maskable characters may be much smaller than the actual length of the example. That could definitely cause the issue you're seeing above.
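A rough sketch of that effect (this is not the repo's actual masking code; the CLS/SEP codepoints are taken from the jsonl example above, padding is assumed to be 0, and the masking rate is exaggerated to make the effect obvious):

import torch

CLS, SEP, PAD = 57344, 57345, 0   # special codepoints as they appear in the jsonl example; PAD assumed
mask_rate = 0.5                   # exaggerated rate, just for illustration

batch = torch.tensor([
    [CLS] + [100] * 18 + [SEP],          # ordinary row: 18 maskable characters
    [CLS] + [100, SEP] * 8 + [PAD] * 3,  # many one-character "sentences": only 8 maskable characters
])

n_to_mask = int(batch.size(1) * mask_rate)   # same target (10 here) for every row in the batch

indices_per_row = []
for row in batch:
    maskable = ((row != PAD) & (row != SEP) & (row != CLS)).nonzero().squeeze(-1)
    indices_per_row.append(maskable[torch.randperm(len(maskable))[:n_to_mask]])

print([len(t) for t in indices_per_row])  # [10, 8] -- the [SEP]-heavy row runs out of maskable characters
torch.stack(indices_per_row)              # RuntimeError: stack expects each tensor to be equal size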

I'm not sure exactly what your final training objective is here, but this seems like an undesirable situation even if the masking methods were fixed to accommodate it, i.e. I'm not sure if you want examples that are something like 20% [SEP] characters and look like:

Nr.[SEP]SI[SEP]Ausgabe ' -Lt[SEP]Freitag.[SEP]...

I can still look at fixing the masking methods to try and accommodate this if you'd like, but either writing your own version of to_examples.py (which should be easy - it's super simple) or changing how you split sentences so that you avoid these super-short sentences might both fix your issue and improve data quality. It really depends on what your end goal is and how many of these super-short sentences you have, though.
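As an illustration of the second suggestion, a minimal sketch that drops very short lines before building examples (the 20-character threshold and the file names are arbitrary):

MIN_CHARS = 20  # arbitrary threshold for this sketch

with open('pretraining_corpus.txt', encoding='utf-8') as f_in, \
        open('filtered_corpus.txt', 'w', encoding='utf-8') as f_out:
    for line in f_in:
        line = line.rstrip('\n')
        if line == '':
            f_out.write('\n')            # keep blank lines: they mark document boundaries
        elif len(line) >= MIN_CHARS:
            f_out.write(line + '\n')     # keep only reasonably long sentences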

stefan-it commented 3 years ago

Thanks for that explanation :hugs:

I did some filtering (sentences with fewer than 5 tokens, and then also sentences with fewer than 50 tokens, were removed from the corpus). Even with only sentences longer than 50 tokens, the pre-training crashes after 1824 steps 😟

Traceback (most recent call last):
  File "training/train.py", line 59, in <module>
    main()
  File "training/train.py", line 55, in main
    trainer.train(resume_from_checkpoint=checkpoint_dir)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1246, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 518, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 558, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/mnt/shiba/training/masking.py", line 285, in __call__
    input_ids, labels, masked_indices = random_mask(padded_batch['input_ids'],
  File "/mnt/shiba/training/masking.py", line 55, in random_mask
    indices_to_mask = torch.stack(indices_to_mask)
RuntimeError: stack expects each tensor to be equal size, but got [251] at entry 0 and [234] at entry 10

:thinking:

Mindful commented 3 years ago

Alright, if filtering out <50-character sentences didn't fix the issue, then clearly I just need to look into it. I'll try to spend some time on it this weekend, although getting the exact batches that are giving you issues would still be a big help.

stefan-it commented 3 years ago

Hi @Mindful ,

sorry for the late reply. Here are the tensors that have different shapes:

First tensor in batch:

  tensor([ 425,  490, 1109, 1268,  236,  958,  290,  708, 1291, 1545, 1200,  682,
         156, 1672,  769,  729, 1442,   12,  913, 1471, 1120, 1067,  883,  904,
        1394,  516,  932, 1662,  524,  915,  756,  579, 1097,  857,  320,   72,             
        1257, 1328,    9,  765, 1382,   26, 1646, 1095, 1468, 1023,  645,  824,                                                    
        1054,  994,  911, 1447, 1017,   81, 1636, 1414, 1217,  163,  629,   89,                      
        1036,  182,  461, 1392,  179,  261,  646,  247, 1435,  200,  743,   55,
         273, 1402,  387,  155, 1470, 1613,  921,  680, 1530,  829,  837,  479,
          48,  509,   24,  817, 1674,  477,  316, 1513,  914,   57, 1371, 1568,
         770,  863, 1252,  936,  668, 1326, 1295,  768, 1555,   73,  898,  286,
        1362,  243, 1510, 1600,  673,  162, 1310,  188,  698, 1255,  221,  889,
         213,  234,  543,  332, 1381, 1262, 1172,  925, 1018,    4, 1111,   44,
         984, 1472,   46, 1458,  901, 1156,  455, 1486,  596,  164, 1122,  385,
         661,  677,  300,  391, 1484,  309, 1116,   11, 1535,  755,  978,  317,
         653,    6,  503,  626, 1115,  776,  398, 1203, 1114,  510,  908,  121,
        1664,  171, 1575,   22, 1544, 1278,  372, 1162,  604, 1637, 1197,  265,
         912,  592,  758, 1138, 1398, 1567,   31,  709,  281, 1079, 1492,  734,
         103, 1171,  458, 1577,  284, 1676,  218, 1187,  573, 1578,  336,  403,
        1476, 1144,  498, 1489,  216,  393, 1350, 1009, 1621,  419,  198,  339,
         637,   45,  814, 1573, 1538, 1195,  622, 1375, 1356, 1305, 1592,  555,
         500, 1273, 1656,  369, 1264, 1571, 1316,  565,  296, 1673,   92, 1227,
        1499,  795,  804,  288,  266,  550,  235, 1427, 1373, 1415,  240])

Last tensor in batch:

 tensor([ 54,  55,  87,  78,  66, 201,  46, 190, 219, 213, 211, 186, 102, 188,
        194, 134, 202, 234, 220, 103, 144,  95,  51,  73, 182,  83, 121, 174,
        216, 185, 124, 196, 159,  19, 133, 189,  74, 173,  12,  65, 163,  63,
        138, 225, 126, 184, 131, 214,  70,  98, 167,  47, 172,  67, 139,   6,       
        150,  94,   7,  21,  61, 169,  20, 222,  45, 122, 104, 158,  27, 224,
         64, 235,  30,  75, 143, 140, 132, 162,  28,  93,  38,  26,  13, 119,
          8,  40,  72, 151,   1, 137, 198,  17, 125, 114,  44,  81,  36, 146,
        123, 223, 130, 147, 204,  50, 161, 171,  80,  34, 116, 154, 100, 120,
        127, 197, 148,  99,  71, 209, 231,   4,  14,  60, 199,  69, 165,  68,
         57,  56,  89, 183, 195, 164, 193, 229, 207,  48, 118, 107, 217, 142,
         37, 203, 226, 181,  35, 108,  29,  88, 105,  79,  18,  84, 135, 153,
         23, 156, 168, 192, 111, 115, 178, 141, 117,  32, 166, 128, 170, 152,
        210,  90, 205,  10,  97, 113,   3, 129,  86, 191,  24, 230, 232, 157,
        112,  77,  52, 228,  25,  33,  49, 110,  62,  76,  59, 149, 101, 145,
         82, 212, 180, 179, 206, 160, 176, 155,  58,  39,  31, 200, 187, 177,
         91,  53, 136,  16, 175, 218, 106, 109,   2,  15, 215, 208,  43,  42,
          5,   9,  11, 221,  22,  85, 227,  41,  96, 233])

First tensor has a length of 251, last tensor 234. Thanks for your help!

stefan-it commented 3 years ago

Oh, I forgot to mention where these tensors come from:

https://github.com/octanove/shiba/blob/54ddbf872286660b84f3434d003b34e662b7776d/training/masking.py#L55

(this is indices_to_mask right before the torch.stack call)

Mindful commented 3 years ago

@stefan-it no worries on the delayed response, I was swamped last week too. In regard to the tensors, I have a pretty good idea of where the problem is and what code is generating the tensors with different shapes (as you say, it's the masking code).

What I was hoping to get from you is the input contents of the batch - either as tensors of token IDs or as raw text - that cause this problem when given as input to the masking code. Basically, I'm trying to get input that will let me reproduce the issue on my machine, which will make it much easier to debug. (Obviously the full batch may be a lot of text, but worst case, if you can just copy and paste the whole thing into a Gist, that should work fine.)
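One way to capture that batch would be to wrap the masking collator and dump the raw examples whenever it raises; a sketch (the wrapper class and file name are hypothetical, not part of the repo):

import json

class DumpOnFailure:
    """Wraps an existing collate_fn and saves the raw input_ids of a failing batch."""

    def __init__(self, collate_fn, dump_path='failing_batch.json'):
        self.collate_fn = collate_fn
        self.dump_path = dump_path

    def __call__(self, examples):
        try:
            return self.collate_fn(examples)
        except RuntimeError:
            # assumes each raw example is a dict with an 'input_ids' sequence, as in the jsonl
            with open(self.dump_path, 'w') as f:
                json.dump([[int(i) for i in ex['input_ids']] for ex in examples], f)
            raise

# Usage sketch: pass the wrapped collator wherever the masking collator is passed now,
# e.g. data_collator=DumpOnFailure(masker) when constructing the Trainer.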

Mindful commented 2 years ago

Closing old issues.