mit-han-lab / temporal-shift-module

[ICCV 2019] TSM: Temporal Shift Module for Efficient Video Understanding
https://arxiv.org/abs/1811.08383
MIT License

Return values of `__getitem__` differ from what the DataLoader returns when training on UCF101 #213

Closed: dengfenglai321 closed this issue 2 years ago

dengfenglai321 commented 2 years ago

Hi, while training on UCF101 I found that the values returned by `__getitem__` do not match the values returned by the DataLoader.

1. Printing the return values of `__getitem__` gives the following:

```python
def __getitem__(self, index):
    record = self.video_list[index]

    # check this is a legit video folder
    if self.image_tmpl == 'flow_{}_{:05d}.jpg':
        file_name = self.image_tmpl.format('x', 1)
        full_path = os.path.join(self.root_path, record.path, file_name)
    elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
        file_name = self.image_tmpl.format(int(record.path), 'x', 1)
        full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
    else:
        file_name = self.image_tmpl.format(1)
        full_path = os.path.join(self.root_path, record.path, file_name)

    # If the first frame is missing, fall back to a random other video.
    # Note: this means the returned sample may not belong to `index`.
    while not os.path.exists(full_path):
        print('################## Not Found:', full_path)
        index = np.random.randint(len(self.video_list))
        record = self.video_list[index]
        if self.image_tmpl == 'flow_{}_{:05d}.jpg':
            file_name = self.image_tmpl.format('x', 1)
            full_path = os.path.join(self.root_path, record.path, file_name)
        elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            file_name = self.image_tmpl.format(int(record.path), 'x', 1)
            full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
        else:
            file_name = self.image_tmpl.format(1)
            full_path = os.path.join(self.root_path, record.path, file_name)

    if not self.test_mode:
        segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
    else:
        segment_indices = self._get_test_indices(record)
    # print('record : {}'.format(record))
    print('segment_indices : {}'.format(segment_indices))
    data, label = self.get(record, segment_indices)
    print('data : {}'.format(data.size()))
    print('label : {}'.format(label))
    # Return the tensors that were just printed; calling self.get() a second
    # time would re-run the (random) training transforms on the frames.
    return data, label
```
2. Printing the values returned by the DataLoader gives the following:

```python
for i, (input, target) in enumerate(train_loader):
    # measure data loading time
    data_time.update(time.time() - end)
    target = target.cuda()
    # Variable is a no-op wrapper since PyTorch 0.4; plain tensors also work
    input_var = torch.autograd.Variable(input)
    target_var = torch.autograd.Variable(target)
    print('\n')
    print('input_var : {}'.format(input_var.size()))
    print('target_var : {}'.format(target_var))
```
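3. The actual output is shown below:

[Screenshot of the printed output]

I have checked this many times and the two sets of return values never match. What could be causing this?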
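A plausible explanation, not confirmed in this issue: a training loader is usually built with `shuffle=True`, and with `num_workers > 0` each worker process calls `__getitem__` on its own and prefetches batches ahead of the training loop, so the prints from `__getitem__` run ahead of, and interleave with, the prints in the loop. The default collation also stacks `batch_size` samples, so `input_var.size()` gains a leading batch dimension that the per-sample `data.size()` lacks, and `target_var` holds `batch_size` labels rather than one. Separately, the `Not Found` retry loop in `__getitem__` swaps in a random `index`, so a missing frame silently changes which video is returned. The sketch below uses a hypothetical `ToyClipDataset` and `check_alignment` helper (stand-ins, not TSM's `TSNDataSet` or training code) that also return the sample `index`, which makes every batch element traceable to the `__getitem__` call that produced it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyClipDataset(Dataset):
    """Hypothetical stand-in for TSNDataSet: each sample is a clip tensor of
    shape (num_segments * 3, H, W) plus an integer label."""

    def __init__(self, num_videos=10, num_segments=8, size=4):
        self.labels = [i % 5 for i in range(num_videos)]
        self.num_segments = num_segments
        self.size = size

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Fill the clip with its own index so every sample is recognizable.
        data = torch.full((self.num_segments * 3, self.size, self.size), float(index))
        label = self.labels[index]
        # Also return the index so the training loop can match each batch
        # element back to the __getitem__ call that produced it.
        return data, label, index


def check_alignment(batch_size=4, shuffle=True, num_workers=0):
    """Iterate once over the loader, verify every element against the
    dataset, and return the shape of each collated batch."""
    dataset = ToyClipDataset()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                        num_workers=num_workers)
    shapes = []
    for data, target, index in loader:
        # Default collation stacks samples along a new leading dimension:
        # (T*3, H, W) per sample becomes (B, T*3, H, W) per batch.
        shapes.append(tuple(data.shape))
        for d, t, i in zip(data, target, index):
            # Each element of the batch still matches its own sample.
            assert t.item() == dataset.labels[i.item()]
            assert torch.all(d == float(i.item()))
    return shapes


if __name__ == "__main__":
    print(check_alignment())
```

With 10 toy clips and `batch_size=4` the collated shapes are `[(4, 24, 4, 4), (4, 24, 4, 4), (2, 24, 4, 4)]`, and every element passes the check even though the order is shuffled. While debugging, passing `shuffle=False, num_workers=0` makes the `__getitem__` prints and the loop prints appear in the same order, which is an easy way to confirm whether the mismatch is only an ordering effect.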