lhoyer / DAFormer

[CVPR22] Official Implementation of DAFormer: Improving Network Architectures and Training Strategies for Domain-Adaptive Semantic Segmentation

How to count the class statistics in Figure S1 of the supplementary material? #32

Closed · qianyuzqy closed this issue 2 years ago

qianyuzqy commented 2 years ago

Dear Lukas,

Thank you for your great work on DAFormer. I am very interested in Figure S1 of the supplementary material, which depicts the class statistics of the corresponding dataset for 10k samples.

However, I am having trouble modifying mmseg/datasets/uda_dataset.py. Could you please tell me how to change the code to compute the class statistics shown in Figure S1?

Thank you in advance! I have also read your excellent paper HRDA; I really hope it gets accepted to ECCV 2022. Good luck to you!

Best Regards,

lhoyer commented 2 years ago

Dear Qianyu,

Thank you very much! Here is a code snippet similar to the one I used for the paper. I did some refactoring afterward, so some names/function definitions might not match; I tried to update them, but I might have missed something.

import json

import torch
from experiments import setup_rcs
from mmcv import Config
from tqdm import tqdm

from mmseg.apis import set_random_seed
from mmseg.datasets import build_dataloader, build_dataset

if __name__ == '__main__':
    # Number of source-domain samples to aggregate statistics over.
    N_SAMPLES_TOTAL = int(1e4)
    SOURCE = 'gta'
    # Rare Class Sampling temperature.
    TEMP = 0.01

    # Load the UDA dataset config for the chosen source domain and
    # enable Rare Class Sampling (RCS) with the given temperature.
    f = 'configs/_base_/datasets/' \
        f'uda_{SOURCE}_to_cityscapes_512x512.py'
    cfg = {'_base_': [f]}
    cfg = setup_rcs(cfg, TEMP)

    cfg = json.dumps(cfg)
    cfg = Config.fromstring(cfg, '.json')
    cfg.seed = 0
    cfg.data.samples_per_gpu = 1
    set_random_seed(cfg.seed, deterministic=True)

    dataset = build_dataset(cfg.data.train)
    data_loader = build_dataloader(
        dataset,
        cfg.data.samples_per_gpu,
        cfg.data.workers_per_gpu,
        dist=False,
        seed=cfg.seed,
        drop_last=True)

    # Accumulate per-class pixel counts over the sampled training data.
    pixel_class_stats = 0
    n_samples = 0
    for batch in tqdm(data_loader, total=N_SAMPLES_TOTAL):
        # Flatten the ground truth and drop ignored pixels.
        gt = batch['gt_semantic_seg'].data[0].reshape(-1)
        gt = gt[gt != dataset.ignore_index]
        # Count pixels per class; minlength ensures a fixed-size histogram.
        bincount = torch.bincount(
            gt, minlength=len(dataset.CLASSES)).long()
        pixel_class_stats += bincount
        n_samples += 1
        if n_samples >= N_SAMPLES_TOTAL:
            break

    out = {
        'name': f'rcm{TEMP}',
        'n_samples': n_samples,
        'pixel_class_stats': pixel_class_stats.tolist(),
    }
    print(f'rcm{TEMP}_{SOURCE}_stats =', out)
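
To get from the accumulated pixel counts to a per-class frequency plot in the style of Figure S1, the counts can simply be normalized by their sum. A minimal sketch (the matplotlib part and the plot_class_stats helper are illustrative additions, not part of the original script):

import matplotlib.pyplot as plt
import torch


def plot_class_stats(pixel_class_stats, class_names, out_file='class_stats.png'):
    # Normalize the accumulated per-class pixel counts to frequencies.
    counts = torch.tensor(pixel_class_stats, dtype=torch.float64)
    freqs = counts / counts.sum()

    # Bar plot of the per-class pixel frequencies (log scale), similar
    # in spirit to Figure S1 of the supplementary material.
    plt.figure(figsize=(10, 3))
    plt.bar(range(len(class_names)), freqs.tolist())
    plt.xticks(range(len(class_names)), class_names, rotation=90)
    plt.ylabel('Pixel frequency')
    plt.yscale('log')
    plt.tight_layout()
    plt.savefig(out_file)


# Example usage with the statistics collected above:
# plot_class_stats(out['pixel_class_stats'], dataset.CLASSES)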

Best regards, Lukas

qianyuzqy commented 2 years ago

Thank you very much for your kind help.

lhoyer commented 2 years ago

You're welcome!