zamling / PSALM

[ECCV2024] This is an official implementation for "PSALM: Pixelwise SegmentAtion with Large Multi-Modal Model"

Fine-tuning on customized dataset #20

Open ys-zong opened 4 days ago

ys-zong commented 4 days ago

Hi, thanks for the nice work! I'm trying to fine-tune PSALM on a customized dataset, and I wonder what specific modifications are needed (e.g., how should I organize the data format, data loader, etc.)? For example, if I want to further fine-tune on gRefCOCO, can the current gRefCOCO evaluation dataset be used directly for fine-tuning? Many thanks!

zamling commented 4 days ago

Hi @ys-zong, if you want to fine-tune on gRefCOCO, you need to rebuild the data pipeline:

  1. You can follow datasets/build_RefCOCO.py to prepare the gRefCOCO training data. Each sample should contain a sentence and several mask annotations (see the sketch after this list).
  2. Following RefCOCO_dataset here, build the gRefCOCO_dataset; it seems you just need to change the I/O of this dataset class.
  3. Train the model following the RefCOCO pipeline; providing data_dict['dataset_type'] = 'referring_coco' makes the model train in the RefCOCO pipeline.
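
For reference, a minimal sketch of what one prepared training sample might look like. The field names follow the gRefCOCO_dataset class posted further down this thread; every concrete value here is illustrative, not real gRefCOCO data:

sample = {
    'image_info': {
        'file_name': 'COCO_train2014_000000000123.jpg',
        'height': 480,
        'width': 640,
    },
    'new_img_id': 123,  # unique image id used by the dataset mapper
    'instruction': [{'sent': 'the man on the left'}],  # referring sentence(s)
    'anns': [  # may be an empty list for gRefCOCO's no-target samples
        {
            'bbox': [10.0, 20.0, 200.0, 300.0],  # XYXY_ABS
            'category_id': 1,
            'area': 12345.0,
            'segmentation': [[10.0, 20.0, 200.0, 20.0, 200.0, 300.0, 10.0, 300.0]],  # COCO polygon
        },
    ],
}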
ys-zong commented 3 days ago

Thanks for the instructions. I tried it, but I guess the main problem is how to deal with the "no-target" samples in gRefCOCO - the anns field will be an empty list. If I directly use RefCOCO_dataset, there is an error here during preprocessing. If I use gRefcoco_Dataset (the main difference seems to be in this line), it gives an error while calculating the loss:

[rank0]:   File "/mypath/PSALM/psalm/model/mask_decoder/Mask2Former_Simplify/utils/matcher.py", line 208, in forward
[rank0]:     return self.memory_efficient_forward(outputs, targets)
[rank0]:   File "/opt/conda/envs/psalm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/mypath/PSALM/psalm/model/mask_decoder/mask_criterion/pretrain_criterion.py", line 417, in memory_efficient_forward
[rank0]:     tgt_mask = targets[b]["masks"].to(out_mask)
[rank0]: TypeError: 'NoneType' object is not subscriptable

I wonder if the current model supports this? Similarly, how did you handle/evaluate the no-target predictions during gRefCOCO evaluation? Did you evaluate the N-Acc defined in the original paper? Thanks!

zamling commented 3 days ago

@ys-zong Yes, you are right. We do not deal with the no-target situation (empty annotations) during training, so you may need to rewrite the dataset mapper. When we fine-tune on gRefCOCO, we add an all-zero mask to the annotations. Here is my gRefCOCO dataset for training:

import os

import numpy as np
import torch
from detectron2.structures import BoxMode
from pycocotools.mask import encode

# RefCOCO_dataset_pooling and REFER_TOKEN_INDEX are defined elsewhere in the
# PSALM codebase; import them from wherever they live in your checkout.

class gRefCOCO_dataset(RefCOCO_dataset_pooling):
    def __getitem__(self, idx):
        mask_type = 'polygon'
        data = self.data[idx]
        image_file = data['image_info']['file_name']
        image_folder = self.data_args.refcoco_image_folder

        data_dict = {}
        data_dict['file_name'] = os.path.join(image_folder, image_file)
        data_dict['height'] = data['image_info']['height']
        data_dict['width'] = data['image_info']['width']
        data_dict['image_id'] = data['new_img_id']
        data_dict['annotations'] = data['anns']
        for annotation in data_dict['annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            # map COCO category ids to contiguous ids; ids that are already
            # contiguous are left as-is
            if annotation['category_id'] in self.coco_id_to_cont_id:
                annotation['category_id'] = self.coco_id_to_cont_id[annotation['category_id']]
            elif annotation['category_id'] in self.coco_id_to_cont_id.values():
                pass
            else:
                raise ValueError(f"unexpected category_id: {annotation['category_id']}")
            annotation['image_id'] = data['new_img_id']

        if len(data_dict['annotations']) == 0:
            # build negative mask
            annotation = {}
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            annotation['category_id'] = 1
            annotation['area'] = 0.0
            annotation['bbox'] = [0,0,0,0]

            binary_mask = np.zeros((data_dict['height'],data_dict['width'])).astype(np.uint8)
            mask_type = 'bitmask'

            # RLE-encode the all-zero mask so the sample carries one valid
            # (empty) bitmask annotation instead of an empty list
            segmentation = encode(np.asfortranarray(binary_mask))
            segmentation = {
                'counts': segmentation['counts'].decode('ascii'),
                'size': segmentation['size'],
            }
            annotation['segmentation'] = segmentation

            data_dict['annotations'] = [annotation]

        if isinstance(self.data_args.image_processor, dict):
            processor = self.data_args.image_processor['instance']
        else:
            processor = self.data_args.image_processor
        data_dict = processor.preprocess(data_dict, mask_format=mask_type)
        sentences = data['instruction']
        prefix_inst = 'This is an image <image>, Please doing Referring Segmentation according to the following instruction:'
        # concatenate all referring sentences into a single instruction string
        instruction = ''
        for sent in sentences:
            instruction += ' {}.'.format(sent['sent'])

        if self.data_args.seg_last:
            sources = [[{'from': 'human', 'value': prefix_inst + '\n<refer>'},
                        {'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]
        else:
            sources = [[{'from': 'human', 'value': prefix_inst + '\n<refer>'},
                        {'from': 'gpt', 'value': '\nSure, the segmentation result is'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]
        labels = text_dict['labels'][0]

        token_refer_id = self.preprocess_referring_instruction(instruction)
        refer_embedding_indices = torch.zeros_like(input_ids)
        refer_embedding_indices[input_ids == REFER_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = labels
        data_dict['dataset_type'] = 'referring_coco'

        data_dict['token_refer_id'] = token_refer_id
        data_dict['refer_embedding_indices'] = refer_embedding_indices
        return data_dict

You can try this.
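
As a quick standalone sanity check (a sketch, not part of the PSALM codebase): the all-zero negative mask round-trips through COCO RLE encoding, so the matcher receives a valid (H, W) target instead of a None/empty annotation list:

import numpy as np
from pycocotools.mask import decode, encode

h, w = 480, 640  # hypothetical image size
rle = encode(np.asfortranarray(np.zeros((h, w), dtype=np.uint8)))
rle = {'counts': rle['counts'].decode('ascii'), 'size': rle['size']}

# decoding the ascii-serialized RLE restores the empty mask
restored = decode({'counts': rle['counts'].encode('ascii'), 'size': rle['size']})
assert restored.shape == (h, w) and restored.sum() == 0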

zamling commented 3 days ago

For evaluation on gRefCOCO, as mentioned in the paper, we set a threshold: if no predicted mask scores >= 0.6, the sample is treated as 'no-target'. The results are shown in the attached image.
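
A minimal sketch of that decision rule (function and variable names here are illustrative, not PSALM's actual evaluation code):

import torch

NO_TARGET_THRESHOLD = 0.6  # the threshold mentioned above

def select_masks(mask_scores: torch.Tensor, masks: torch.Tensor):
    """mask_scores: (N,) per-mask confidence; masks: (N, H, W) binary masks."""
    keep = mask_scores >= NO_TARGET_THRESHOLD
    if not keep.any():
        return None  # predict 'no-target' for this sample
    return masks[keep]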

ys-zong commented 2 days ago

Thanks! I used this dataset class and it works well. I wonder what hyper-parameters you used for training (batch size, learning rate, epochs, etc.)? I did get an improvement after fine-tuning, but it's around 60 cIoU and 65 N-acc, which is much lower than yours.