Bobholamovic / CDLab

Yet another repository for developing and benchmarking deep learning-based change detection methods.
The Unlicense
196 stars 30 forks source link

关于加载训练集和验证集问题 #26

Open taojunhui opened 1 year ago

taojunhui commented 1 year ago

 if self.is_training:
            # Build the datasets and the optimizer.
            # Both loaders come from the same data_factory call; only the
            # phase name ('train' vs. 'eval') differs between them.
            self.train_loader = data_factory(dataset, 'train', settings)
            self.eval_loader = data_factory(dataset, 'eval', settings)
            self.optimizer = optim_factory(optimizer, self.model, settings)

以上是训练器构建训练集和验证集的代码。但是,我从以下代码发现,WHUDataset 在加载训练集和验证集时是以 subset 参数来区分的,而调用时两者都采用了默认值 'val'。这样的话,训练集和验证集不就是同一份数据了吗?请问是我疏忽了哪些细节吗?

class WHUDataset(CDDataset):
    """Change-detection dataset laid out as ``<root>/<subset>/{A,B,label}/*.png``.

    ``A`` and ``B`` hold the two image lists (``t1``/``t2``) and ``label``
    holds the targets. All constructor arguments are forwarded unchanged to
    ``CDDataset``.

    Args:
        root: Dataset root directory.
        phase: Phase name forwarded to the base class (default ``'train'``).
        transforms: 3-tuple of transforms forwarded to the base class.
        repeats: Repeat count forwarded to the base class.
        subset: Sub-directory of ``root`` to read from (default ``'val'``).
            NOTE(review): the base class appears to override this with
            ``'train'`` when ``phase == 'train'`` -- confirm against
            ``DatasetBase.__init__``.
    """

    def __init__(
        self,
        root, phase='train',
        transforms=(None, None, None),
        repeats=1,
        subset='val'
    ):
        super().__init__(root, phase, transforms, repeats, subset)

    def _read_file_paths(self):
        """Return the sorted ``(t1_list, t2_list, tar_list)`` PNG path lists.

        The three lists are matched purely by sorted position, so the three
        folders are expected to contain correspondingly named files.

        Raises:
            AssertionError: If the three folders hold different numbers of
                PNG files.
        """
        subset_dir = join(self.root, self.subset)
        t1_list = sorted(glob(join(subset_dir, 'A', '*.png')))
        t2_list = sorted(glob(join(subset_dir, 'B', '*.png')))
        tar_list = sorted(glob(join(subset_dir, 'label', '*.png')))
        assert len(t1_list) == len(t2_list) == len(tar_list), (
            'Mismatched file counts under {}: A={}, B={}, label={}'.format(
                subset_dir, len(t1_list), len(t2_list), len(tar_list)
            )
        )
        return t1_list, t2_list, tar_list

    def fetch_target(self, target_path):
        """Load the target via the base class and convert it to a boolean array.

        The base-class result is divided by 255 and cast to ``bool``, so every
        non-zero value becomes ``True``.
        """
        # `np.bool` was only an alias of the builtin `bool`; it was deprecated
        # in NumPy 1.20 and removed in 1.24, so the builtin is used directly.
        return (super().fetch_target(target_path) / 255).astype(bool)
@DATA.register_func('WHU_train_dataset')
def build_whu_train_dataset(C):
    """Build the training data loader for the WHU dataset.

    Starts from the common training configs, then sets the transform triple
    and the dataset root before handing everything to
    ``build_train_dataloader``.

    Args:
        C: Configuration mapping; this function reads ``C['mu']`` and
            ``C['sigma']`` and passes ``C`` on to the helpers it calls.

    Returns:
        Whatever ``build_train_dataloader`` returns for ``WHUDataset``.
    """
    configs = get_common_train_configs(C)

    # First slot: a Choose over flips, rotations, a shift and the identity.
    augmentation = Choose(
        HorizontalFlip(),
        VerticalFlip(),
        Rotate('90'),
        Rotate('180'),
        Rotate('270'),
        Shift(),
        Identity(),
    )
    # Second slot: normalization with the configured mean and std.
    normalization = Normalize(np.asarray(C['mu']), np.asarray(C['sigma']))

    configs['transforms'] = (augmentation, normalization, None)
    configs['root'] = constants.IMDB_WHU

    # Imported here rather than at module level, as in the original code.
    from data.whu import WHUDataset
    return build_train_dataloader(WHUDataset, configs, C)

@DATA.register_func('WHU_eval_dataset')
def build_whu_eval_dataset(C):
    """Build the evaluation data loader for the WHU dataset.

    Starts from the common evaluation configs, replaces the transform triple
    with a normalization-only one, and sets the dataset root.

    Args:
        C: Configuration mapping; reads ``C['mu']``, ``C['sigma']``,
            ``C['batch_size']``, ``C['num_workers']`` and ``C['device']``,
            and is also passed to ``get_common_eval_configs``.

    Returns:
        A ``DataLoader`` over ``WHUDataset`` with shuffling disabled and
        incomplete final batches kept.
    """
    configs = get_common_eval_configs(C)

    # Only the middle slot carries a transform: no augmentation at eval time.
    normalization = Normalize(np.asarray(C['mu']), np.asarray(C['sigma']))
    configs['transforms'] = (None, normalization, None)
    configs['root'] = constants.IMDB_WHU

    # Imported here rather than at module level, as in the original code.
    from data.whu import WHUDataset
    dataset = WHUDataset(**configs)

    loader_options = dict(
        batch_size=C['batch_size'],
        shuffle=False,
        num_workers=C['num_workers'],
        drop_last=False,
        # Pinned memory is requested for every device other than 'cpu'.
        pin_memory=C['device'] != 'cpu',
    )
    return DataLoader(dataset, **loader_options)
def build_train_dataloader(cls, configs, C):
    """Instantiate ``cls`` with ``configs`` and wrap it in a training loader.

    Args:
        cls: Dataset class (or any callable) invoked as ``cls(**configs)``.
        configs: Keyword arguments for ``cls``.
        C: Configuration mapping; reads ``C['batch_size']``,
            ``C['num_workers']`` and ``C['device']``.

    Returns:
        A ``data.DataLoader`` that shuffles and drops the last incomplete
        batch.
    """
    dataset = cls(**configs)
    loader_options = {
        'batch_size': C['batch_size'],
        'shuffle': True,
        'num_workers': C['num_workers'],
        # Pinned memory is requested for every device other than 'cpu'.
        'pin_memory': C['device'] != 'cpu',
        'drop_last': True,
    }
    return data.DataLoader(dataset, **loader_options)
Bobholamovic commented 1 year ago

因为训练过程默认不需要使用到验证集,所以subset默认被设置为'val'。在评测模型精度时,可以设置subset参数为'test'来获取测试集上的精度指标。就像README中的这个例子:

python train.py eval --exp_config PATH_TO_CONFIG_FILE --resume PATH_TO_CHECKPOINT --save_on --subset test
taojunhui commented 1 year ago

我又看了一下代码,找到了,感谢。

taojunhui commented 1 year ago

原来是在这里进行配置的:

def get_common_train_configs(C):
    """Return the dataset configs shared by every training-phase builder.

    Args:
        C: Configuration mapping; only ``C['repeats']`` is read.

    Returns:
        A new dict with ``phase`` fixed to ``'train'`` and ``repeats`` taken
        from ``C``.
    """
    return {'phase': 'train', 'repeats': C['repeats']}

def get_common_eval_configs(C):
    """Return the dataset configs shared by every evaluation-phase builder.

    Args:
        C: Configuration mapping; only ``C['subset']`` is read.

    Returns:
        A new dict with ``phase`` fixed to ``'eval'``, a fresh list of three
        ``None`` placeholders under ``transforms``, and ``subset`` taken
        from ``C``.
    """
    return {
        'phase': 'eval',
        'transforms': [None] * 3,
        'subset': C['subset'],
    }
# Base class for the datasets, built on data.Dataset with ABCMeta as its
# metaclass. Only part of __init__ is shown in this excerpt; the dotted
# lines mark code that was elided by the person quoting it.
class DatasetBase(data.Dataset, metaclass=ABCMeta):
    def __init__(
        self, 
        root, phase,
        transforms,
        repeats, 
        subset
    ):
        super().__init__()
        # Expand a leading '~' in the dataset root to the user's home directory.
        self.root = os.path.expanduser(root)
       .
       .
       .
        # In the training phase the subset is always 'train'; the `subset`
        # argument only takes effect for other phases.
        # NOTE(review): self.phase is presumably assigned in the elided
        # lines above -- confirm.
        self.subset = 'train' if self.phase == 'train' else subset