yanxp / MetaR-CNN

Meta R-CNN : Towards General Solver for Instance-level Low-shot Learning
https://yanxp.github.io/metarcnn.html

Some bug about training #3

Closed: lzhnb closed this issue 4 years ago

lzhnb commented 4 years ago

I set everything up and ran the training script train_metarcnn.py.

But I got an error like this:

Traceback (most recent call last):
  File "train_metarcnn.py", line 220, in <module>
    [('2007', 'trainval')], metaclass, img_size, shots=shots, shuffle=True)
  File "/mnt/Disk1/zhliang/code/MetaR-CNN/lib/datasets/metadata.py", line 58, in __init__
    img = torch.from_numpy(np.array(prn_image[key][i]))
IndexError: list index out of range

The error is located here (lib/datasets/metadata.py, around line 58):

        prn_image, prn_mask = self.get_prndata()
        for i in range(shots):
            cls = []
            data = []
            for n, key in enumerate(list(prn_image.keys())):
                img = torch.from_numpy(np.array(prn_image[key][i]))
                img = img.unsqueeze(0)
                mask = torch.from_numpy(np.array(prn_mask[key][i]))
                mask = mask.unsqueeze(0)
                mask = mask.unsqueeze(3)
                imgmask = torch.cat([img, mask], dim=3)
                data.append(imgmask.permute(0, 3, 1, 2).contiguous())
                cls.append(class_to_idx[key])
            self.prncls.append(cls)
            self.prndata.append(torch.cat(data,dim=0))
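For readers following the tensor ops in this loop, here is a minimal pure-Python sketch that only tracks shapes, assuming (as the unsqueeze/cat/permute sequence implies) that each image is stored as H × W × 3 and each mask as H × W. The helper names are hypothetical and the 224 × 224 size is an illustrative value, not the img_size actually used by train_metarcnn.py. It shows how every class contributes one 4-channel (RGB + mask) image in channels-first layout, and how one entry of self.prndata stacks those images along dim 0.

```python
# Shape-only bookkeeping mirroring the loop in metadata.py.
# No tensors are created; shapes are plain tuples.
# Helper names and the 224x224 size are hypothetical, for illustration only.

def unsqueeze(shape, dim):
    """Insert a size-1 axis at `dim`, like Tensor.unsqueeze."""
    if dim < 0:
        dim += len(shape) + 1
    return shape[:dim] + (1,) + shape[dim:]

def cat(shapes, dim):
    """Concatenate along `dim`; all other axes must match, like torch.cat."""
    first = shapes[0]
    for s in shapes[1:]:
        same_rank = len(s) == len(first)
        others_match = all(a == b for i, (a, b) in enumerate(zip(s, first)) if i != dim)
        if not (same_rank and others_match):
            raise ValueError("shape mismatch outside dim {}".format(dim))
    return first[:dim] + (sum(s[dim] for s in shapes),) + first[dim + 1:]

def permute(shape, order):
    """Reorder axes, like Tensor.permute."""
    return tuple(shape[i] for i in order)

def prn_sample_shape(height, width):
    """Shape of one class's masked image after the loop body."""
    img = unsqueeze((height, width, 3), 0)               # (1, H, W, 3)
    mask = unsqueeze(unsqueeze((height, width), 0), 3)   # (1, H, W, 1)
    imgmask = cat([img, mask], dim=3)                    # (1, H, W, 4)
    return permute(imgmask, (0, 3, 1, 2))                # (1, 4, H, W)

def prn_batch_shape(num_classes, height, width):
    """Shape of one self.prndata entry: one masked image per class on dim 0."""
    return cat([prn_sample_shape(height, width)] * num_classes, dim=0)

print(prn_batch_shape(15, 224, 224))  # (15, 4, 224, 224)
```

The fourth channel is the binary mask, which is why the PRN input has 4 channels rather than 3.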

Because train_metarcnn.py is run with args.phase=1, it passes shots=200 to the MetaDataset initializer.

Then I printed the relevant values:

        self.prndata = []
        self.prncls = []
        prn_image, prn_mask = self.get_prndata()
        print('shots: {}'.format(shots))
        for i in range(shots):
            cls = []
            data = []
            for n, key in enumerate(list(prn_image.keys())):
                print('prn_image_{}: {}'.format(key, len(prn_image[key])))
                img = torch.from_numpy(np.array(prn_image[key][i]))
                img = img.unsqueeze(0)
                mask = torch.from_numpy(np.array(prn_mask[key][i]))
                mask = mask.unsqueeze(0)
                mask = mask.unsqueeze(3)
                imgmask = torch.cat([img, mask], dim=3)
                data.append(imgmask.permute(0, 3, 1, 2).contiguous())
                cls.append(class_to_idx[key])
            self.prncls.append(cls)
            self.prndata.append(torch.cat(data,dim=0))

and found this:

shots: 200
prn_image_bicycle: 196
prn_image_car: 600
prn_image_pottedplant: 145
prn_image_aeroplane: 232
prn_image_cat: 315
prn_image_person: 600
prn_image_boat: 160
prn_image_dog: 372
prn_image_horse: 220
prn_image_tvmonitor: 188
prn_image_bottle: 150
prn_image_diningtable: 100
prn_image_sheep: 94
prn_image_chair: 260
prn_image_train: 249

The number of images differs between classes, and some classes have fewer than 200. Since the loop reads prn_image[key][i] for every i in range(200), indexing those classes must go out of range.
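To make the failure point concrete, here is a small stdlib sketch using the per-class counts printed above. The function names are hypothetical. Because the loop indexes every class with the same i, it raises IndexError as soon as i reaches the smallest class count, which here is sheep with 94 images.

```python
# Per-class image counts as printed in the debug output above (phase 1, shots=200).
PRN_IMAGE_COUNTS = {
    "bicycle": 196, "car": 600, "pottedplant": 145, "aeroplane": 232,
    "cat": 315, "person": 600, "boat": 160, "dog": 372, "horse": 220,
    "tvmonitor": 188, "bottle": 150, "diningtable": 100, "sheep": 94,
    "chair": 260, "train": 249,
}

def short_classes(counts, shots):
    """Classes that cannot supply `shots` images, so prn_image[key][i] fails for them."""
    return sorted(key for key, n in counts.items() if n < shots)

def first_failing_index(counts, shots):
    """Smallest i in range(shots) at which some class runs out of images, or None."""
    smallest = min(counts.values())
    return smallest if smallest < shots else None

print(first_failing_index(PRN_IMAGE_COUNTS, 200))  # 94
print(short_classes(PRN_IMAGE_COUNTS, 200))
```

Seven of the 15 classes (bicycle, boat, bottle, diningtable, pottedplant, sheep, tvmonitor) have fewer than 200 images.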

In your paper, $D_{meta}$ and $D_{train}$ are mutually exclusive, so I guessed that the intent is to load a fixed number of images into the meta dataset. If we need 200 shots and there are 15 classes in $C_{meta}$ for the VOC dataset, we collect about 13 shots from each class (200 / 15 is about 13.3). So I changed the code as follows, and it runs:

        self.prndata = []
        self.prncls = []
        prn_image, prn_mask = self.get_prndata()
        cls = []
        data = []
        counts = shots/len(prn_image.keys())
        for n, key in enumerate(list(prn_image.keys())):
            for i in range(int(counts)):
                img = torch.from_numpy(np.array(prn_image[key][i]))
                img = img.unsqueeze(0)
                mask = torch.from_numpy(np.array(prn_mask[key][i]))
                mask = mask.unsqueeze(0)
                mask = mask.unsqueeze(3)
                imgmask = torch.cat([img, mask], dim=3)
                data.append(imgmask.permute(0, 3, 1, 2).contiguous())
                cls.append(class_to_idx[key])
        self.prncls.append(cls)
        self.prndata.append(torch.cat(data,dim=0))
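The modified loop's index selection can be sketched in plain Python as below. This is an illustration, not the repository's code: the function names are hypothetical, and the min(quota, n) guard is an addition that the modified code above does not have, so a class with fewer images than the quota contributes what it has instead of raising IndexError. For positive integers, shots // num_classes equals int(shots / num_classes), so with 15 classes each gets 13 images and 15 × 13 = 195 of the 200-shot budget is used.

```python
def per_class_quota(shots, num_classes):
    """Integer share of the shot budget per class (same as int(shots / num_classes) here)."""
    return shots // num_classes

def select_indices(counts, shots):
    """Map each class to the image indices a balanced per-class loop would read.

    The min(quota, n) guard is a hypothetical addition: it caps each class at the
    number of images it actually has instead of raising IndexError.
    """
    quota = per_class_quota(shots, len(counts))
    return {key: list(range(min(quota, n))) for key, n in counts.items()}

picked = select_indices({"sheep": 94, "car": 600}, 200)
print({key: len(idx) for key, idx in picked.items()})  # {'sheep': 94, 'car': 100}
```

One structural difference worth checking against the training loop: the original code appends one self.prndata entry per shot index (each holding one image per class), while the modified version appends a single entry holding all selected images.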

Is this right? I also found that self.shots is three times larger than shots. Maybe I missed some detail in the paper; I will read it again.

Thanks

lzhnb commented 4 years ago

Assuming the dataset is now fine, when I train I get an out-of-memory error like this:

Loaded dataset `voc_2012_train_first_split` for training
Set proposal method: gt
Appending horizontally-flipped training examples...
wrote gt roidb to /mnt/Disk1/zhliang/code/MetaR-CNN/data/cache/voc_2012_train_first_split_gt_roidb.pkl
done
Preparing training data...
done
before filtering, there are 12330 images...
after filtering, there are 12330 images...
12330 roidb entries
Loading pretrained weights from data/pretrained_model/resnet101_caffe.pth
THCudaCheck FAIL file=/pytorch/torch/lib/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "train_metarcnn.py", line 384, in <module>
    num_boxes_list)
  File "/home/lzhnb/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/Disk1/zhliang/code/MetaR-CNN/lib/model/faster_rcnn/faster_rcnn.py", line 61, in forward
    attentions = self.prn_network(prn_data)
  File "/mnt/Disk1/zhliang/code/MetaR-CNN/lib/model/faster_rcnn/resnet.py", line 328, in prn_network
    base_feat = self.RCNN_base(self.meta_conv1(im_data))
  ...
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58

The error happens in prn_network (lib/model/faster_rcnn/resnet.py):

  def prn_network(self,im_data):
    '''
    the Predictor-head Remodeling Network (PRN)
    :param im_data:
    :return attention vectors:
    '''
    base_feat = self.RCNN_base(self.meta_conv1(im_data))
    feature = self._head_to_tail(self.max_pooled(base_feat))
    attentions = self.sigmoid(feature)
    return  attentions

My server: i7-8700K CPU, 2 × Titan Xp (12 GB), 2 × 16 GB RAM, Ubuntu 16.04, CUDA 8.0 + cuDNN 7.1.4, torch==0.3.1, torchvision==0.2.0.

Is there anything wrong with how I am running training?

Thanks

yanxp commented 4 years ago

In the first phase, the metadata draws its shots from the VOC2007 and VOC2012 image sets, so there are more than 200 shots. In the second phase, we set self.shots = 3 * shots, drawn from the VOC2007 image set. I have updated the code. As for the out-of-memory problem, you can reduce the batch_size.

mandal4 commented 4 years ago


Could you explain the purpose of setting self.shots = 3 * shots?