valeoai / xmuda

Cross-Modal Unsupervised Domain Adaptation for 3D Semantic Segmentation

The code of Vanilla Fusion or xMUDA Fusion? #5

Closed weiliuxm closed 4 years ago

weiliuxm commented 4 years ago

Thank you for sharing the code, it is great! Do you have any plans to release the code for Vanilla Fusion or xMUDA Fusion?

maxjaritz commented 4 years ago

Hi, thanks! We currently do not plan to release the fusion part, as it is not part of the main experiment and would make the repo more complex. If you would like to implement it yourself, you only need to make a minimal change to the architecture (concatenate the features from 2D and 3D), as explained in the paper; a sketch follows below.
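
For reference, a minimal sketch of such a late-fusion head (the module and parameter names here, e.g. FuseHead and feat_channels_2d, are made up for illustration; the actual backbones in this repo differ):

import torch
import torch.nn as nn

class FuseHead(nn.Module):
    """Hypothetical late-fusion head: concatenates the per-point 2D and 3D
    features and predicts fused segmentation logits."""

    def __init__(self, feat_channels_2d, feat_channels_3d, num_classes):
        super().__init__()
        self.linear = nn.Linear(feat_channels_2d + feat_channels_3d, num_classes)

    def forward(self, feats_2d, feats_3d):
        # feats_2d: (num_points, C_2d), image features lifted to the 3D points
        # feats_3d: (num_points, C_3d), point features from the 3D network
        return self.linear(torch.cat([feats_2d, feats_3d], dim=1))

In the dual-head setup, the model would return these fused logits ('seg_logit') alongside the per-modality outputs ('seg_logit_2d', 'seg_logit_3d') so that the cross-modal loss can still be applied.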

Let me know if you need more details.

Best, Max

weiliuxm commented 4 years ago

Thank you for your reply.

taeyeopl commented 3 years ago

Thanks for sharing this great work. I have a simple question about the fusion loss terms.

Q1. Did you use an additional segmentation loss and KL divergence loss for P_fuse, analogous to the 2D and 3D segmentation losses? For the fusion experiments (Figure 4 in the main paper, Table 2, Table 3), which loss options were applied?

maxjaritz commented 3 years ago

Please have a look at the following fusion training script (it might need some adjustments to work with the published code):

#!/usr/bin/env python
import os
import os.path as osp
import argparse
import logging
import time
import socket
import warnings

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from xmuda.common.solver.build import build_optimizer, build_scheduler
from xmuda.common.utils.checkpoint import CheckpointerV2
from xmuda.common.utils.logger import setup_logger
from xmuda.common.utils.metric_logger import MetricLogger
from xmuda.common.utils.torch_util import set_random_seed
from xmuda.models.build import build_model_fuse
from xmuda.data.build import build_dataloader
from xmuda.data.utils.validate import validate
from xmuda.models.losses import entropy_loss

def parse_args():
    parser = argparse.ArgumentParser(description='xMUDA training')
    parser.add_argument(
        '--cfg',
        dest='config_file',
        default='',
        metavar='FILE',
        help='path to config file',
        type=str,
    )
    parser.add_argument(
        'opts',
        help='Modify config options using the command-line',
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    return args

def init_metric_logger(metric_list):
    new_metric_list = []
    for metric in metric_list:
        if isinstance(metric, (list, tuple)):
            new_metric_list.extend(metric)
        else:
            new_metric_list.append(metric)
    metric_logger = MetricLogger(delimiter='  ')
    metric_logger.add_meters(new_metric_list)
    return metric_logger

def train(cfg, output_dir='', run_name=''):
    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # ---------------------------------------------------------------------------- #
    logger = logging.getLogger('xmuda.train')

    set_random_seed(cfg.RNG_SEED)

    # build fuse model
    model_fuse, train_metric_fuse = build_model_fuse(cfg)
    logger.info('Build fuse model:\n{}'.format(str(model_fuse)))
    num_params = sum(param.numel() for param in model_fuse.parameters())
    print('#Parameters: {:.2e}'.format(num_params))

    model_fuse = model_fuse.cuda()

    # build optimizer
    optimizer_fuse = build_optimizer(cfg, model_fuse)

    # build lr scheduler
    scheduler_fuse = build_scheduler(cfg, optimizer_fuse)

    # build checkpointer
    # Note that checkpointer will load state_dict of model, optimizer and scheduler.
    checkpointer_fuse = CheckpointerV2(model_fuse,
                                     optimizer=optimizer_fuse,
                                     scheduler=scheduler_fuse,
                                     save_dir=output_dir,
                                     logger=logger,
                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data_fuse = checkpointer_fuse.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build tensorboard logger (optional)
    if output_dir:
        tb_dir = osp.join(output_dir, 'tb.{:s}'.format(run_name))
        summary_writer = SummaryWriter(tb_dir)
    else:
        summary_writer = None

    # ---------------------------------------------------------------------------- #
    # Train
    # ---------------------------------------------------------------------------- #
    max_iteration = cfg.SCHEDULER.MAX_ITERATION
    start_iteration = checkpoint_data_fuse.get('iteration', 0)

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=start_iteration)
    train_dataloader_trg = build_dataloader(cfg, mode='train', domain='target', start_iteration=start_iteration)
    val_period = cfg.VAL.PERIOD
    val_dataloader = build_dataloader(cfg, mode='val', domain='target') if val_period > 0 else None

    best_metric_name = 'best_{}'.format(cfg.VAL.METRIC)
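    # validate() receives the fuse model in the 2D slot (see below), so its
    # metrics are logged under the '2d' key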
    best_metric = {
        '2d': checkpoint_data_fuse.get(best_metric_name, None),
    }
    best_metric_iter = {'2d': -1}
    logger.info('Start training from iteration {}'.format(start_iteration))

    # add metrics
    train_metric_logger = init_metric_logger([train_metric_fuse])
    val_metric_logger = MetricLogger(delimiter='  ')

    def setup_train():
        # set training mode
        model_fuse.train()
        # reset metric
        train_metric_logger.reset()

    def setup_validate():
        # set evaluate mode
        model_fuse.eval()
        # reset metric
        val_metric_logger.reset()

    if cfg.TRAIN.CLASS_WEIGHTS:
        class_weights = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS).cuda()
    else:
        class_weights = None

    if cfg.TRAIN.CLASS_WEIGHTS_PL:
        class_weights_pl = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS_PL).cuda()
    else:
        class_weights_pl = None

    setup_train()
    end = time.time()
    train_iter_src = enumerate(train_dataloader_src)
    train_iter_trg = enumerate(train_dataloader_trg)
    for iteration in range(start_iteration, max_iteration):
        # fetch data_batches for source & target
        _, data_batch_src = train_iter_src.__next__()
        _, data_batch_trg = train_iter_trg.__next__()
        data_time = time.time() - end
        # copy data from cpu to gpu
        if 'SCN' in cfg.DATASET_SOURCE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:
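            # data_batch['x'] = [coords, feats]; only the features are moved to the GPU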
            # source
            data_batch_src['x'][1] = data_batch_src['x'][1].cuda()
            data_batch_src['seg_label'] = data_batch_src['seg_label'].cuda()
            data_batch_src['img'] = data_batch_src['img'].cuda()
            # target
            data_batch_trg['x'][1] = data_batch_trg['x'][1].cuda()
            # data_batch_trg['seg_label'] = data_batch_trg['seg_label'].cuda()
            data_batch_trg['img'] = data_batch_trg['img'].cuda()
            if cfg.TRAIN.XMUDA.lambda_pl > 0:
                data_batch_trg['pseudo_label_2d'] = data_batch_trg['pseudo_label_2d'].cuda()
                # data_batch_trg['pseudo_label_3d'] = data_batch_trg['pseudo_label_3d'].cuda()
        else:
            raise NotImplementedError('Only SCN is supported for now.')

        optimizer_fuse.zero_grad()

        # ---------------------------------------------------------------------------- #
        # Train on source
        # ---------------------------------------------------------------------------- #

        preds_fuse = model_fuse(data_batch_src)

        # segmentation loss: cross entropy
        seg_loss_src_fuse = F.cross_entropy(preds_fuse['seg_logit'], data_batch_src['seg_label'], weight=class_weights)
        train_metric_logger.update(seg_loss_src_fuse=seg_loss_src_fuse)
        loss_fuse = seg_loss_src_fuse

        if cfg.TRAIN.XMUDA.lambda_xm_src > 0:
            # cross-modal loss: KL divergence
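            # P_fuse is detached below: the fused prediction acts as the teacher,
            # and the 2D/3D heads are trained to mimic it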
            seg_logit_2d = preds_fuse['seg_logit_2d']
            seg_logit_3d = preds_fuse['seg_logit_3d']
            xm_loss_src_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),
                                      F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
                                      reduction='none').sum(1).mean()
            xm_loss_src_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),
                                      F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
                                      reduction='none').sum(1).mean()
            train_metric_logger.update(xm_loss_src_2d=xm_loss_src_2d,
                                       xm_loss_src_3d=xm_loss_src_3d)
            loss_fuse += cfg.TRAIN.XMUDA.lambda_xm_src * (xm_loss_src_2d + xm_loss_src_3d)

        # update metric (e.g. IoU)
        with torch.no_grad():
            train_metric_fuse.update_dict(preds_fuse, data_batch_src)

        # backward
        loss_fuse.backward()

        # ---------------------------------------------------------------------------- #
        # Train on target
        # ---------------------------------------------------------------------------- #

        preds_fuse = model_fuse(data_batch_trg)

        loss_fuse = []
        if cfg.TRAIN.XMUDA.lambda_xm_trg > 0:
            # cross-modal loss: KL divergence
            seg_logit_2d = preds_fuse['seg_logit_2d']
            seg_logit_3d = preds_fuse['seg_logit_3d']
            xm_loss_trg_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),
                                      F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
                                      reduction='none').sum(1).mean()
            xm_loss_trg_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),
                                      F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
                                      reduction='none').sum(1).mean()
            train_metric_logger.update(xm_loss_trg_2d=xm_loss_trg_2d,
                                       xm_loss_trg_3d=xm_loss_trg_3d)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_xm_trg * xm_loss_trg_2d)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_xm_trg * xm_loss_trg_3d)
        if cfg.TRAIN.XMUDA.lambda_pl > 0:
            # self-training loss with pseudo labels
            # Note that the fused labels must be stored in 'pseudo_label_2d'
            pl_loss_trg_fuse = F.cross_entropy(preds_fuse['seg_logit'], data_batch_trg['pseudo_label_2d'],
                                               weight=class_weights_pl)
            train_metric_logger.update(pl_loss_trg_fuse=pl_loss_trg_fuse)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_pl * pl_loss_trg_fuse)
        if cfg.TRAIN.XMUDA.lambda_minent > 0:
            # MinEnt
            minent_loss_trg_fuse = entropy_loss(F.softmax(preds_fuse['seg_logit'], dim=1))
            train_metric_logger.update(minent_loss_trg_fuse=minent_loss_trg_fuse)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_minent * minent_loss_trg_fuse)

        # guard against an empty loss list (e.g. when only lambda_xm_src > 0)
        if loss_fuse:
            sum(loss_fuse).backward()

        optimizer_fuse.step()

        batch_time = time.time() - end
        train_metric_logger.update(time=batch_time, data=data_time)

        # log
        cur_iter = iteration + 1
        if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):
            logger.info(
                train_metric_logger.delimiter.join(
                    [
                        'iter: {iter:4d}',
                        '{meters}',
                        'lr: {lr:.2e}',
                        'max mem: {memory:.0f}',
                    ]
                ).format(
                    iter=cur_iter,
                    meters=str(train_metric_logger),
                    lr=optimizer_fuse.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                )
            )

        # summary
        if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:
            keywords = ('loss', 'acc', 'iou')
            for name, meter in train_metric_logger.meters.items():
                if all(k not in name for k in keywords):
                    continue
                summary_writer.add_scalar('train/' + name, meter.avg, global_step=cur_iter)

        # checkpoint
        if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:
            checkpoint_data_fuse['iteration'] = cur_iter
            checkpoint_data_fuse[best_metric_name] = best_metric['2d']
            checkpointer_fuse.save('model_fuse_{:06d}'.format(cur_iter), **checkpoint_data_fuse)

        # ---------------------------------------------------------------------------- #
        # validate for one epoch
        # ---------------------------------------------------------------------------- #
        if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):
            start_time_val = time.time()
            setup_validate()

            validate(cfg,
                     model_fuse,
                     None,
                     val_dataloader,
                     val_metric_logger)

            epoch_time_val = time.time() - start_time_val
            logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(
                cur_iter, val_metric_logger.summary_str, epoch_time_val))

            # summary
            if summary_writer is not None:
                keywords = ('loss', 'acc', 'iou')
                for name, meter in val_metric_logger.meters.items():
                    if all(k not in name for k in keywords):
                        continue
                    summary_writer.add_scalar('val/' + name, meter.avg, global_step=cur_iter)

            # best validation
            for modality in ['2d']:
                cur_metric_name = cfg.VAL.METRIC + '_' + modality
                if cur_metric_name in val_metric_logger.meters:
                    cur_metric = val_metric_logger.meters[cur_metric_name].global_avg
                    if best_metric[modality] is None or best_metric[modality] < cur_metric:
                        best_metric[modality] = cur_metric
                        best_metric_iter[modality] = cur_iter

            # restore training
            setup_train()

        scheduler_fuse.step()
        end = time.time()

    for modality in ['2d']:
        logger.info('Best val-{}-{} = {:.2f} at iteration {}'.format(modality.upper(),
                                                                     cfg.VAL.METRIC,
                                                                     best_metric[modality] * 100,
                                                                     best_metric_iter[modality]))

def main():
    args = parse_args()

    # load the configuration
    # import on-the-fly to avoid overwriting cfg
    from xmuda.common.config import purge_cfg
    from xmuda.config.xmuda import cfg
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    purge_cfg(cfg)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    # replace '@' with config path
    if output_dir:
        config_path = osp.splitext(args.config_file)[0]
        output_dir = output_dir.replace('@', config_path.replace('configs/', ''))
        if osp.isdir(output_dir):
            warnings.warn('Output directory exists.')
        os.makedirs(output_dir, exist_ok=True)

    # run name
    timestamp = time.strftime('%m-%d_%H-%M-%S')
    hostname = socket.gethostname()
    run_name = '{:s}.{:s}'.format(timestamp, hostname)

    logger = setup_logger('xmuda', output_dir, comment='train.{:s}'.format(run_name))
    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))
    logger.info(args)

    logger.info('Loaded configuration file {:s}'.format(args.config_file))
    logger.info('Running with config:\n{}'.format(cfg))

    # in fusion, dual head is necessary to apply cross-modal loss
    assert cfg.MODEL_2D.DUAL_HEAD == cfg.MODEL_3D.DUAL_HEAD
    # check if there is at least one loss on target set
    assert cfg.TRAIN.XMUDA.lambda_xm_src > 0 or cfg.TRAIN.XMUDA.lambda_xm_trg > 0 or cfg.TRAIN.XMUDA.lambda_pl > 0 or \
           cfg.TRAIN.XMUDA.lambda_minent > 0
    train(cfg, output_dir, run_name)

if __name__ == '__main__':
    main()

taeyeopl commented 3 years ago

Thanks for your reply with the code.

taeyeopl commented 3 years ago

I observed something different from what I expected. Based on the figure below and equations 2 and 3 in the main paper, my understanding is that the code and the arrow directions in the figure do not match. Q1. Can you explain this part? Have I missed something?

seg_logit_2d = preds_fuse['seg_logit_2d']
seg_logit_3d = preds_fuse['seg_logit_3d']

# My understanding based on the figure
xm_loss_src_2d = F.kl_div(F.log_softmax(preds_fuse['seg_logit'], dim=1),
                          F.softmax(seg_logit_2d.detach(), dim=1),
                          reduction='none').sum(1).mean()
xm_loss_src_3d = F.kl_div(F.log_softmax(preds_fuse['seg_logit'], dim=1),
                          F.softmax(seg_logit_3d.detach(), dim=1),
                          reduction='none').sum(1).mean()

# Your provided code
xm_loss_src_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),
                          F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
                          reduction='none').sum(1).mean()
xm_loss_src_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),
                          F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
                          reduction='none').sum(1).mean()

[figure: cross-modal loss diagram from the paper, showing the arrow directions between P_2D, P_3D, and P_fuse]

maxjaritz commented 3 years ago

Hi, The D_KL notation might be confusing, as it is defined as D_KL(P||Q) where P is the target. It is worth having a look at the definition of the Kullback-Leibler divergence: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

The following explains it:

Consider two probability distributions P and Q. Usually, P represents the data, the observations, or a measured probability distribution. Distribution Q represents instead a theory, a model, a description or an approximation of P.

This means that P is the target, in our case P_fuse. The arrow direction models the flow of information, i.e. P_fuse teaches P_3D->fuse and P_2D->fuse.
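
A quick self-contained check of the argument order in PyTorch (the numbers here are made up for illustration):

import torch
import torch.nn.functional as F

# F.kl_div(input, target) expects input = log Q and target = P,
# and computes KL(P || Q) = sum_i P_i * (log P_i - log Q_i),
# so the teacher distribution goes in the second argument.
p = torch.tensor([[0.7, 0.2, 0.1]])           # teacher, e.g. P_fuse
q_logits = torch.tensor([[2.0, 1.0, 0.5]])    # student logits, e.g. the 2D head

kl = F.kl_div(F.log_softmax(q_logits, dim=1), p, reduction='none').sum(1)

# manual check against the definition
q = F.softmax(q_logits, dim=1)
manual = (p * (p.log() - q.log())).sum(1)
assert torch.allclose(kl, manual)

This is why the provided code puts F.softmax(preds_fuse['seg_logit'].detach(), ...) in the target slot: the fused prediction is the teacher.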

Hope that helps.

taeyeopl commented 3 years ago

Thanks, now I clearly understand the KL loss.