gidariss / FewShotWithoutForgetting


Could you please give me some advice for improving acc_both? #14

Open asdfqwer2015 opened 5 years ago

asdfqwer2015 commented 5 years ago

Hi gidariss, thanks for sharing your code. I tested this mechanism with my own feature model and dataset, and got a high acc_base and acc_novel but not a high acc_both. Could you please give me some advice?

I trained a model on my own dataset (more than 110k training samples) and got acc_base 91.22%, acc_novel 90.48%, and acc_both 74.97% in stage 1. The model is training in stage 2 now, but acc_both is still very low. Also, for stage 2, should I use the feature net with the best acc_both from stage 1 instead of the one with the best acc_novel?

Thanks.

EmmanouelP commented 5 years ago

@asdfqwer2015 Did you use your own implementation of the method? If not, could you please share what changes need to be made in order to use this code on one's own dataset? (If it's not too much trouble, of course.)

asdfqwer2015 commented 5 years ago

Thanks for your reply, and sorry for my late response; my network connection is not very good. I modified the feature net and the dataloader, and slightly modified FewShot.py to handle the extra frames axis. I tested this few-shot mechanism for training gesture classification with a 3D-conv feature net from https://github.com/ahmetgunduz/Real-time-GesRec, and the dataset is 20bn-jester, which has 27 classes.

For the feature net, I almost copied the code from here: https://github.com/ahmetgunduz/Real-time-GesRec/blob/4aaf03c5a6569a2d385839545eca01aca35011f6/model.py#L41-L47 , and nFeat is 2048.
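
In outline, the feature net only has to map a video clip tensor to a flat 2048-dimensional vector. Here is a rough sketch of that interface (the wrapper class and its name are illustrative only; the real code is the ResNeXt backbone from Real-time-GesRec with its last fully-connected layer dropped):

import torch
import torch.nn as nn

class VideoFeatureNet(nn.Module):  # illustrative wrapper, not the actual GesRec code
    def __init__(self, backbone, nFeat=2048):
        super().__init__()
        self.backbone = backbone              # 3D-conv trunk without its final fc layer
        self.pool = nn.AdaptiveAvgPool3d(1)   # collapse the remaining T' x H' x W' grid
        self.nFeat = nFeat

    def forward(self, x):
        # x: (batch, C, T, H, W) video clips
        f = self.backbone(x)                  # (batch, nFeat, T', H', W')
        return self.pool(f).flatten(1)        # (batch, nFeat) == (batch, 2048)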

For the classifier net, I used the one from the FewShotWithoutForgetting framework, with nFeat set to 2048 to match the feature net.
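
For reference, that classifier is a cosine-similarity classifier. The minimal sketch below only shows where nFeat=2048 plugs in; it is not the framework's actual class (which also contains the few-shot weight generator), and nClassesBase is just the number of base classes in the split:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineClassifier(nn.Module):  # simplified sketch of the cosine-classifier idea
    def __init__(self, nFeat=2048, nClassesBase=22, scale=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(nClassesBase, nFeat) * 0.01)
        self.scale = nn.Parameter(torch.tensor(scale))    # learnable score scale

    def forward(self, features):
        # features: (batch, nFeat) output of the feature net
        features = F.normalize(features, dim=1)
        weights = F.normalize(self.weight, dim=1)
        return self.scale * features.matmul(weights.t())  # scaled cosine scores per class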

For the dataloader, the code is below:

import torch
import torch.utils.data as data
# Compose, Scale, CenterCrop, ToTensor, Normalize, MultiScaleRandomCrop,
# SpatialElasticDisplacement, TemporalCenterCrop and TemporalRandomCrop come from
# the Real-time-GesRec transforms; buildLabelIndex comes from this framework;
# the jester_* helpers and TemporalRandomHalfFpsCrop are my own code.


class Jester(data.Dataset):
    def __init__(self, phase='train', novel_labelidx=[], sample_duration=16, sample_size=112, half_fps=False,
                 do_not_use_random_transf=False):
        n_samples_for_each_video = 1
        self.modality = 'RGB'
        train_initial_scale = 1.0
        n_train_scales = 5
        scale_step = 0.84089641525
        self.sample_duration = sample_duration
        self.sample_size = sample_size
        self.half_fps = half_fps
        root_path = '../datasets/20bn-jester/20bn-jester-v1'
        annotation_path = '../annotation_Jester/jester.json'
        self.loader = jester_video_loader()

        self.base_folder = 'jester'
        assert(phase=='train' or phase=='val' or phase=='test')
        self.phase = phase
        self.name = 'Jester_' + phase

        print('Loading Jester dataset - phase {0}'.format(phase))

        if self.phase=='train':
            # During the training phase we only load the training-split videos
            # of the training categories (aka base categories).
            self.data, self.class_names = jester_make_dataset(root_path, annotation_path, 'training',
                                                              n_samples_for_each_video, sample_duration)
            self.data, self.class_names, novel_labelidx = jester_trans_labelidx(self.data, self.class_names,
                                                                                novel_labelidx)
            self.data = list(filter(lambda x: x['label'] not in novel_labelidx, self.data))
            self.class_names = {k: v for k, v in self.class_names.items() if k not in novel_labelidx}
            self.labels = list(map(lambda x: x['label'], self.data))

            self.label2ind = buildLabelIndex(self.labels)
            self.labelIds = sorted(self.label2ind.keys())
            self.num_cats = len(self.labelIds)
            self.labelIds_base = self.labelIds
            self.num_cats_base = len(self.labelIds_base)

        elif self.phase=='val' or self.phase=='test':
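            # note: the 'test' and 'val' branches below are currently identical;
            # both load the same 'validation' split of Jester.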
            if self.phase=='test':
                # load data that will be used for evaluating the recognition
                # accuracy of the base categories.
                data, class_names = jester_make_dataset(root_path, annotation_path, 'validation',
                                                        n_samples_for_each_video, sample_duration)
                data, class_names, novel_labelidx = jester_trans_labelidx(data, class_names, novel_labelidx)
                data_base = list(filter(lambda x: x['label'] not in novel_labelidx, data))
                # load data that will be used for evaluating the few-shot recognition
                # accuracy on the novel categories.
                data_novel = list(filter(lambda x: x['label'] in novel_labelidx, data))
            else: # phase=='val'
                # load data that will be used for evaluating the recognition
                # accuracy of the base categories.
                data, class_names = jester_make_dataset(root_path, annotation_path, 'validation',
                                                        n_samples_for_each_video, sample_duration)
                data, class_names, novel_labelidx = jester_trans_labelidx(data, class_names, novel_labelidx)
                data_base = list(filter(lambda x: x['label'] not in novel_labelidx, data))
                # load data that will be used for evaluating the few-shot recognition
                # accuracy on the novel categories.
                data_novel = list(filter(lambda x: x['label'] in novel_labelidx, data))

            self.data = data_base + data_novel
            self.labels = list(map(lambda x: x['label'], self.data))
            labels_base = list(map(lambda x: x['label'], data_base))
            labels_novel = list(map(lambda x: x['label'], data_novel))
            self.class_names = class_names

            self.label2ind = buildLabelIndex(self.labels)
            self.labelIds = sorted(self.label2ind.keys())
            self.num_cats = len(self.labelIds)

            # self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
            # self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
            self.labelIds_base = buildLabelIndex(labels_base).keys()
            self.labelIds_novel = buildLabelIndex(labels_novel).keys()
            self.num_cats_base = len(self.labelIds_base)
            self.num_cats_novel = len(self.labelIds_novel)
            intersection = set(self.labelIds_base) & set(self.labelIds_novel)
            assert(len(intersection) == 0)
        else:
            raise ValueError('Not valid phase {0}'.format(self.phase))

        if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
            self.spatial_transform = Compose([
                Scale(self.sample_size),
                CenterCrop(self.sample_size),
                ToTensor(255), Normalize([0, 0, 0], [1, 1, 1])
            ])
            self.temporal_transform = Compose([
                TemporalRandomHalfFpsCrop(),
                TemporalCenterCrop(self.sample_duration)
            ])
            self.target_transform = None
        else:
            scales = [train_initial_scale]
            for i in range(1, n_train_scales):
                scales.append(scales[-1] * scale_step)
            self.spatial_transform = Compose([
                MultiScaleRandomCrop(scales, self.sample_size),
                # SpatialElasticDisplacement(100*2., 100*0.08),
                SpatialElasticDisplacement(),
                ToTensor(255), Normalize([0, 0, 0], [1, 1, 1])
            ])
            self.temporal_transform = Compose([
                TemporalRandomHalfFpsCrop(),
                TemporalRandomCrop(self.sample_duration)
            ])
            self.target_transform = None

    def __getitem__(self, index):
        path = self.data[index]['video']

        frame_indices = self.data[index]['frame_indices']
        if self.temporal_transform is not None:
            frame_indices = self.temporal_transform(frame_indices)
        clip = self.loader(path, frame_indices, self.modality, self.sample_duration)

        if self.spatial_transform is not None:
            self.spatial_transform.randomize_parameters()
            clip = [self.spatial_transform(img) for img in clip]

        im_dim = clip[0].size()[-2:]
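        # stack the T frame tensors and permute from (T, C, H, W) to (C, T, H, W),
        # which is the input layout the 3D-conv feature net expects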
        clip = torch.cat(clip, 0).view((self.sample_duration, -1) + im_dim).permute(1, 0, 2, 3)

        target = self.labels[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return clip, target

    def __len__(self):
        return len(self.data)
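
A quick shape sanity-check of this dataset looks like the sketch below. The novel label indices are only an example, not my actual split, and for the few-shot training/evaluation stages the framework wraps the dataset in its own episodic dataloader rather than a plain DataLoader:

from torch.utils.data import DataLoader

# hypothetical smoke test of the Jester dataset above; the label indices are examples only
train_set = Jester(phase='train', novel_labelidx=[22, 23, 24, 25, 26])
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

clips, targets = next(iter(train_loader))
print(clips.shape)    # expected: torch.Size([8, 3, 16, 112, 112]) -> (batch, C, T, H, W)
print(targets.shape)  # expected: torch.Size([8])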

For FewShot.py, I only modified these lines to handle the frames axis: https://github.com/gidariss/FewShotWithoutForgetting/blob/0efdc78ceddf2395b5b0bef4f789f9d63c761b9b/algorithms/FewShot.py#L126-L129 https://github.com/gidariss/FewShotWithoutForgetting/blob/0efdc78ceddf2395b5b0bef4f789f9d63c761b9b/algorithms/FewShot.py#L198-L205
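
The change is essentially that every example tensor now carries an extra frames axis, so the reshape in front of the feature net keeps (C, T, H, W) per example instead of (C, H, W). A sketch of that idea (not the exact diff; the function name is mine):

import torch

def flatten_examples(clips):
    # clips: (batch_size, num_examples, C, T, H, W) video clips of one episode;
    # merge the first two dims so the 3D-conv feature net sees (N, C, T, H, W)
    batch_size, num_examples = clips.size(0), clips.size(1)
    return clips.view(batch_size * num_examples, *clips.size()[2:])

dummy = torch.randn(2, 5, 3, 16, 112, 112)   # 2 episodes x 5 examples of 16-frame clips
feat_input = flatten_examples(dummy)         # shape: (10, 3, 16, 112, 112)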

asdfqwer2015 commented 5 years ago

Five gesture classes are chosen as the novel classes. In stage 2, AccuracyBase and AccuracyNovel have not increased, and I think AccuracyNovel is already very high. But AccuracyBoth dropped from 75.x% to 59.x%. Could you please give me some advice on how to handle this dataset? Thanks. :)

EmmanouelP commented 5 years ago

@asdfqwer2015 A million thanks for your reply. I will try and test this on my own dataset or the one you pointed to, and will give you some feedback as soon as I can.

asdfqwer2015 commented 5 years ago

Thanks for your help. Looking forward to your advice. :)