megvii-research / TreeEnergyLoss

[CVPR2022] Tree Energy Loss: Towards Sparsely Annotated Semantic Segmentation

What is PytorchEncoding/train.pth? I couldn't find the corresponding file or any documentation for it #9

Open YsasaYsasa opened 1 year ago

YsasaYsasa commented 1 year ago
```python
self.train_mask = torch.load(os.path.join(self.args.ori_root_dir, "PytorchEncoding/train.pth"))
self.val_mask = torch.load(os.path.join(self.args.ori_root_dir, "PytorchEncoding/val.pth"))
```
YsasaYsasa commented 1 year ago

Why does the data preparation step use the PASCAL Context dataset while the test code uses VOC2012? I can't find the .pth files for VOC2012 that the two lines above load. How should I preprocess the data?

JackyWang2001 commented 1 year ago

For PASCAL_CONTEXT, the PytorchEncoding files can be found in the commented-out lines here.

For VOC2012, I created the encoding myself, taking the segmentation masks and data split files from this page. Here is how I created it:

```python
import os
import argparse

import PIL.Image as Image
import torch


def read_ids(split_file):
    """Read image IDs from a VOC aug split file.

    Each line looks like '/JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png',
    so characters [12:23] are the 11-character image ID (e.g. '2007_000032').
    Blank lines are skipped so they do not turn into empty IDs.
    """
    with open(split_file, 'r') as f:
        return [line[12:23] for line in f if line.strip()]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--split_dir', type=str, required=True, dest='split_dir',
                        help="The directory containing train_aug.txt and trainval_aug.txt")
    parser.add_argument('--mask_dir', type=str, required=True, dest='mask_dir',
                        help="The directory of VOC2012's SegmentationClassAug, containing 12031 pngs")
    parser.add_argument('--save_dir', type=str, required=True, dest='save_dir',
                        help="The directory to save train.pth and val.pth")
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    train_ids = read_ids(os.path.join(args.split_dir, 'train_aug.txt'))
    trainval_ids = read_ids(os.path.join(args.split_dir, 'trainval_aug.txt'))

    # The validation split is trainval_aug minus train_aug; sort it for a reproducible order.
    val_ids = sorted(set(trainval_ids) - set(train_ids))

    # Map each image ID to its label mask as a single-channel ('L') PIL image.
    train_mask = {
        k: Image.open(os.path.join(args.mask_dir, k + '.png')).convert('L') for k in train_ids
    }
    val_mask = {
        k: Image.open(os.path.join(args.mask_dir, k + '.png')).convert('L') for k in val_ids
    }

    # The values are pickled PIL images, so on PyTorch >= 2.6 (where weights_only
    # defaults to True) reading them back needs torch.load(..., weights_only=False).
    torch.save(train_mask, os.path.join(args.save_dir, 'train.pth'))
    torch.save(val_mask, os.path.join(args.save_dir, 'val.pth'))
    print('Done!')
```
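The `line[12:23]` slice in the script only works if every line of `train_aug.txt` and `trainval_aug.txt` starts with `/JPEGImages/` followed by an 11-character ID. A minimal sketch, assuming the same two-column `<image path> <mask path>` layout, of taking the ID from the image filename instead, so a different prefix or ID length does not silently break the split. `parse_ids` and `derive_val_ids` are hypothetical helper names, not part of the repository, and the sample lines are illustrative only (their train/val assignment is made up):

```python
import os
from typing import Iterable, List


def parse_ids(lines: Iterable[str]) -> List[str]:
    """Extract image IDs from split-file lines (hypothetical helper).

    Assumes each non-empty line looks like
    '/JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png';
    the ID is the basename of the first column without its extension.
    """
    ids = []
    for line in lines:
        fields = line.split()
        if not fields:  # skip blank lines, e.g. a trailing newline
            continue
        ids.append(os.path.splitext(os.path.basename(fields[0]))[0])
    return ids


def derive_val_ids(train_ids: Iterable[str], trainval_ids: Iterable[str]) -> List[str]:
    """Validation IDs = trainval minus train, sorted for a reproducible order."""
    return sorted(set(trainval_ids) - set(train_ids))


if __name__ == "__main__":
    # Illustrative lines in the assumed split-file format.
    train_lines = [
        "/JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png\n",
        "/JPEGImages/2007_000039.jpg /SegmentationClassAug/2007_000039.png\n",
    ]
    trainval_lines = train_lines + [
        "/JPEGImages/2007_000033.jpg /SegmentationClassAug/2007_000033.png\n",
        "\n",
    ]
    train_ids = parse_ids(train_lines)
    print(train_ids)  # ['2007_000032', '2007_000039']
    # For this layout the fixed slice from the script yields the same ID.
    print(train_lines[0][12:23])  # 2007_000032
    print(derive_val_ids(train_ids, parse_ids(trainval_lines)))  # ['2007_000033']
```

Parsing by filename costs nothing extra and also drops blank lines, which the bare slice would turn into an empty ID (and then a failing `Image.open` on `.png`).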