Open — Ysasa opened this issue 1 year ago
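为什么数据准备阶段用的是 PASCAL Context 数据集,而测试代码用的却是 VOC2012?我找不到 VOC2012 对应的 `.pth` 文件,要怎么进行数据预处理呢?(Why does data preparation use the PASCAL Context dataset while the test code uses VOC2012? I can't find the `.pth` files for VOC2012 — how should the data be preprocessed?)
`self.train_mask = torch.load(os.path.join(self.args.ori_root_dir, "PytorchEncoding/train.pth"))`
`self.val_mask = torch.load(os.path.join(self.args.ori_root_dir, "PytorchEncoding/val.pth"))`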
For PASCAL_CONTEXT, the `PytorchEncoding` files are referenced in the commented-out lines linked here.
For VOC2012, I created the encoding myself. The segmentation masks and data-split files come from the page linked here. This is the script I used to create the encoding:
import os
import argparse
import PIL.Image as Image
import torch
def _parse_id(line):
    """Extract an image ID from one line of a VOC split file.

    Accepts both the augmented-split format, e.g.
    ``/JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png``,
    and a bare ID such as ``2007_000032``. Returns ``None`` for blank lines.
    """
    fields = line.split()
    if not fields:
        return None
    # Take the first path on the line and drop its directory and extension,
    # instead of relying on fixed character offsets.
    return os.path.splitext(os.path.basename(fields[0]))[0]


def _read_ids(path):
    """Return the non-empty image IDs listed in the split file *path*, in file order."""
    with open(path, 'r') as f:
        return [image_id for image_id in map(_parse_id, f) if image_id]


def _load_masks(mask_dir, ids):
    """Load ``<mask_dir>/<id>.png`` for every ID, converted to PIL mode 'L'.

    Returns a dict mapping each ID to its converted PIL image.
    """
    masks = {}
    for image_id in ids:
        with Image.open(os.path.join(mask_dir, image_id + '.png')) as img:
            # convert() returns a new, fully loaded image, so it stays usable
            # after the file is closed.
            # NOTE(review): this assumes the PNGs already store class indices as
            # single-channel 8-bit data (as SegmentationClassAug is expected to);
            # convert('L') on a palette ('P') PNG would map palette colours to
            # luminance rather than keep the indices — confirm for other sources.
            masks[image_id] = img.convert('L')
    return masks


def main(argv=None):
    """Build ``train.pth`` / ``val.pth`` mask encodings for VOC2012 (aug split).

    Args:
        argv: Optional argument list for argparse; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--split_dir', type=str, required=True, dest='split_dir',
                        help="The directory containing train_aug.txt and trainval_aug.txt.")
    parser.add_argument('--mask_dir', type=str, required=True, dest='mask_dir',
                        help="The directory of VOC2012's SegmentationClassAug, containing 12031 pngs")
    parser.add_argument('--save_dir', type=str, required=True, dest='save_dir',
                        help="The directory to save the data.")
    args = parser.parse_args(argv)

    os.makedirs(args.save_dir, exist_ok=True)

    train_ids = _read_ids(os.path.join(args.split_dir, 'train_aug.txt'))
    trainval_ids = _read_ids(os.path.join(args.split_dir, 'trainval_aug.txt'))
    # Validation = trainval minus train; sorted so the saved dict order is
    # reproducible across runs (set iteration order is not).
    val_ids = sorted(set(trainval_ids) - set(train_ids))

    # Save each split as soon as it is built so both are never held in memory together.
    torch.save(_load_masks(args.mask_dir, train_ids), os.path.join(args.save_dir, 'train.pth'))
    torch.save(_load_masks(args.mask_dir, val_ids), os.path.join(args.save_dir, 'val.pth'))
    print('Done!')


if __name__ == '__main__':
    main()