onealwj / MVLT

PyTorch implementation of the BMVC 2022 paper "Masked Vision-Language Transformers for Scene Text Recognition"

RandomWordMaskingGenerator #1

Open songkq opened 1 year ago

songkq commented 1 year ago

@onealwj Hi, I'm confused about why len(label)+1 is used instead of len(label) here in the class AlignCollate when generating the random word mask:

word_mask = torch.cat([torch.from_numpy(self.word_masked_position_generator(
                len(label)+1)).unsqueeze(0) for label in labels], 0)

Why does mask_idc need to be increased by 1 here in the RandomWordMaskingGenerator?

mask_idc = mask_idc + 1

Could you please give some advice?

songkq commented 1 year ago

Another issue, in dataset.py:

def hierarchical_dataset(root, opt, select_data='/', data_filtering_off=False, global_rank=0):
      ...
      # for dirpath, dirnames, filenames in os.walk(root+'/'):
      for dirpath, dirnames, filenames in os.walk(root+'/'+select_data[0]):
      ...
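For what it's worth, the two lines traverse different trees. A minimal illustration of the difference (the directory layout and select_data value below are hypothetical, not from the repo):

import os

# Hypothetical layout: root/MJ/... and root/ST/...
root, select_data = 'data_lmdb', ['MJ']

# The commented-out line walks every sub-dataset under root.
all_dirs = [d for d, _, _ in os.walk(root + '/')]

# The current line walks only the first selected sub-dataset, root/MJ.
mj_dirs = [d for d, _, _ in os.walk(root + '/' + select_data[0])]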
songkq commented 1 year ago

Another issue, in models_mvlt.py: the following does not work in PyTorch==1.8.1.

t_embed = torch.where(
                    w_mask.unsqueeze(-1).expand(-1, -1, self.decoder_embed_dim), text_mask_tokens.float(), t_embed)

RuntimeError: expected scalar type float but found c10::Half

Casting with t_embed.float() works well:

t_embed = torch.where(
                    w_mask.unsqueeze(-1).expand(-1, -1, self.decoder_embed_dim), text_mask_tokens, t_embed.float())
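For context, a minimal repro sketch of the dtype mismatch (illustrative, not repo code): in PyTorch 1.8.x, torch.where requires both value tensors to share a dtype, so a float32 branch paired with a half-precision branch (e.g. under autocast) raises exactly this error.

import torch

mask = torch.zeros(2, 3, dtype=torch.bool)
a = torch.randn(2, 3)          # float32
b = torch.randn(2, 3).half()   # float16, as under mixed precision

# torch.where(mask, a, b)      # RuntimeError on 1.8.x: expected scalar type float but found c10::Half
out = torch.where(mask, a, b.float())  # casting both branches to float32 resolves it

Newer PyTorch versions type-promote the two branches automatically, which is why the original line only fails on older releases.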
onealwj commented 1 year ago

@songkq We use len(label)+1 in the class AlignCollate and mask_idc = mask_idc + 1 in the RandomWordMaskingGenerator because we use the mask token as a separator token between the visual tokens and the textual tokens. About dataset.py, I don't understand what your issue is. Could you clarify?
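If I follow, a minimal sketch of what that implies (assumed from the explanation above; variable names are illustrative, not the repo's exact code):

import numpy as np

# Assumed layout: the textual sequence seen by the decoder is
#   [SEP(mask), c_0, ..., c_{L-1}]
# i.e. one mask token separating visual from textual tokens, then L characters.
label_len = 5
seq_len = label_len + 1                      # hence len(label)+1 in AlignCollate

# Sample character positions 0..L-1 to mask ...
mask_idc = np.random.permutation(label_len)[:2]
# ... and shift them by 1 so they index past the separator slot.
mask_idc = mask_idc + 1

word_mask = np.zeros(seq_len, dtype=bool)
word_mask[mask_idc] = True                   # slot 0 is the separator mask token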