Closed weiliuxm closed 4 years ago
Hi, Thanks! We currently do not plan to release the fusion part as it is not part of the main experiment and would make the repo more complex. If you would like to implement it yourself, you only need to minimally change the architecture (concat features from 2D and 3D) as explained in the paper.
Let me know if you need more details.
Best, Max
Thank you for your reply.
Thanks for sharing good work. I have a simple question related to the fusion loss terms.
Q1. Did you use an additional segmentation loss and kl_div loss for P_fuse, in the same way as the 2D and 3D seg_loss? For the fusion experiments (Figure 4 in the main paper, Table 2, Table 3), can you explain which loss options were applied?
It is the following:
Please have a look at this fusion training script (it might need some adjustments to work with the published code):
#!/usr/bin/env python
import os
import os.path as osp
import argparse
import logging
import time
import socket
import warnings
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from xmuda.common.solver.build import build_optimizer, build_scheduler
from xmuda.common.utils.checkpoint import CheckpointerV2
from xmuda.common.utils.logger import setup_logger
from xmuda.common.utils.metric_logger import MetricLogger
from xmuda.common.utils.torch_util import set_random_seed
from xmuda.models.build import build_model_fuse
from xmuda.data.build import build_dataloader
from xmuda.data.utils.validate import validate
from xmuda.models.losses import entropy_loss
def parse_args(argv=None):
    """Parse the command-line arguments of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing keep the original behaviour.

    Returns:
        argparse.Namespace with two attributes:
        ``config_file`` (value of ``--cfg``, ``''`` when omitted) and
        ``opts`` (all remaining command-line tokens, used to override
        config options).
    """
    parser = argparse.ArgumentParser(description='xMUDA training')
    parser.add_argument(
        '--cfg',
        dest='config_file',
        default='',
        metavar='FILE',
        help='path to config file',
        type=str,
    )
    # REMAINDER collects every token left after the known options.
    parser.add_argument(
        'opts',
        help='Modify config options using the command-line',
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args(argv)
def init_metric_logger(metric_list):
    """Create a MetricLogger that tracks every metric in ``metric_list``.

    Entries that are lists or tuples are expanded by one level; any other
    entry is registered as a single metric.
    """
    flattened = []
    for entry in metric_list:
        # Wrap single metrics so both cases can be appended uniformly.
        group = entry if isinstance(entry, (list, tuple)) else [entry]
        flattened += group
    tracker = MetricLogger(delimiter=' ')
    tracker.add_meters(flattened)
    return tracker
def _cross_modal_kl(student_logit, teacher_logit):
    """Return the KL divergence pulling the student's prediction toward the teacher's.

    Softmax / log-softmax are taken over dim 1. The teacher logits are
    detached, so gradients only flow into ``student_logit``. The per-class
    terms are summed over dim 1 and the result is averaged over what remains.
    """
    return F.kl_div(F.log_softmax(student_logit, dim=1),
                    F.softmax(teacher_logit.detach(), dim=1),
                    reduction='none').sum(1).mean()


def train(cfg, output_dir='', run_name=''):
    """Train the fusion model on source and target batches.

    Each iteration runs a source pass (cross-entropy on ``seg_label`` plus an
    optional cross-modal KL term) and a target pass (optional cross-modal KL,
    pseudo-label cross-entropy and entropy-minimisation terms), then performs
    one optimizer step. Logging, TensorBoard summaries, checkpointing and
    validation run at the periods given in ``cfg``.

    Args:
        cfg: Configuration node read through attribute access
            (``cfg.TRAIN``, ``cfg.VAL``, ``cfg.SCHEDULER``, ...).
        output_dir: Directory for checkpoints and TensorBoard files. When
            empty, no TensorBoard writer is created.
        run_name: Suffix used in the TensorBoard directory name.

    Raises:
        NotImplementedError: If the source or target dataset type does not
            contain ``'SCN'``.
    """
    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # ---------------------------------------------------------------------------- #
    logger = logging.getLogger('xmuda.train')

    set_random_seed(cfg.RNG_SEED)

    # build fuse model
    model_fuse, train_metric_fuse = build_model_fuse(cfg)
    logger.info('Build fuse model:\n{}'.format(str(model_fuse)))
    num_params = sum(param.numel() for param in model_fuse.parameters())
    print('#Parameters: {:.2e}'.format(num_params))

    model_fuse = model_fuse.cuda()

    # build optimizer
    optimizer_fuse = build_optimizer(cfg, model_fuse)

    # build lr scheduler
    scheduler_fuse = build_scheduler(cfg, optimizer_fuse)

    # build checkpointer
    # Note that checkpointer will load state_dict of model, optimizer and scheduler.
    checkpointer_fuse = CheckpointerV2(model_fuse,
                                       optimizer=optimizer_fuse,
                                       scheduler=scheduler_fuse,
                                       save_dir=output_dir,
                                       logger=logger,
                                       max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data_fuse = checkpointer_fuse.load(cfg.RESUME_PATH,
                                                  resume=cfg.AUTO_RESUME,
                                                  resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build tensorboard logger (optionally by comment)
    if output_dir:
        tb_dir = osp.join(output_dir, 'tb.{:s}'.format(run_name))
        summary_writer = SummaryWriter(tb_dir)
    else:
        summary_writer = None

    # ---------------------------------------------------------------------------- #
    # Train
    # ---------------------------------------------------------------------------- #
    max_iteration = cfg.SCHEDULER.MAX_ITERATION
    start_iteration = checkpoint_data_fuse.get('iteration', 0)

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=start_iteration)
    train_dataloader_trg = build_dataloader(cfg, mode='train', domain='target', start_iteration=start_iteration)
    val_period = cfg.VAL.PERIOD
    val_dataloader = build_dataloader(cfg, mode='val', domain='target') if val_period > 0 else None

    # The fused model's best validation score is tracked under the '2d' key.
    best_metric_name = 'best_{}'.format(cfg.VAL.METRIC)
    best_metric = {
        '2d': checkpoint_data_fuse.get(best_metric_name, None),
    }
    best_metric_iter = {'2d': -1}
    logger.info('Start training from iteration {}'.format(start_iteration))

    # add metrics
    train_metric_logger = init_metric_logger([train_metric_fuse])
    val_metric_logger = MetricLogger(delimiter=' ')

    def setup_train():
        # set training mode
        model_fuse.train()
        # reset metric
        train_metric_logger.reset()

    def setup_validate():
        # set evaluate mode
        model_fuse.eval()
        # reset metric
        val_metric_logger.reset()

    if cfg.TRAIN.CLASS_WEIGHTS:
        class_weights = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS).cuda()
    else:
        class_weights = None
    if cfg.TRAIN.CLASS_WEIGHTS_PL:
        class_weights_pl = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS_PL).cuda()
    else:
        class_weights_pl = None

    setup_train()
    end = time.time()
    train_iter_src = enumerate(train_dataloader_src)
    train_iter_trg = enumerate(train_dataloader_trg)
    for iteration in range(start_iteration, max_iteration):
        # fetch data_batches for source & target
        _, data_batch_src = next(train_iter_src)
        _, data_batch_trg = next(train_iter_trg)
        data_time = time.time() - end

        # copy data from cpu to gpu
        if 'SCN' in cfg.DATASET_SOURCE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:
            # source
            data_batch_src['x'][1] = data_batch_src['x'][1].cuda()
            data_batch_src['seg_label'] = data_batch_src['seg_label'].cuda()
            data_batch_src['img'] = data_batch_src['img'].cuda()
            # target
            data_batch_trg['x'][1] = data_batch_trg['x'][1].cuda()
            data_batch_trg['img'] = data_batch_trg['img'].cuda()
            if cfg.TRAIN.XMUDA.lambda_pl > 0:
                data_batch_trg['pseudo_label_2d'] = data_batch_trg['pseudo_label_2d'].cuda()
        else:
            raise NotImplementedError('Only SCN is supported for now.')

        optimizer_fuse.zero_grad()

        # ---------------------------------------------------------------------------- #
        # Train on source
        # ---------------------------------------------------------------------------- #
        preds_fuse = model_fuse(data_batch_src)

        # segmentation loss: cross entropy
        seg_loss_src_fuse = F.cross_entropy(preds_fuse['seg_logit'], data_batch_src['seg_label'], weight=class_weights)
        train_metric_logger.update(seg_loss_src_fuse=seg_loss_src_fuse)
        loss_fuse = seg_loss_src_fuse

        if cfg.TRAIN.XMUDA.lambda_xm_src > 0:
            # cross-modal loss: KL divergence
            # The detached fused prediction is the target for both uni-modal heads.
            xm_loss_src_2d = _cross_modal_kl(preds_fuse['seg_logit_2d'], preds_fuse['seg_logit'])
            xm_loss_src_3d = _cross_modal_kl(preds_fuse['seg_logit_3d'], preds_fuse['seg_logit'])
            train_metric_logger.update(xm_loss_src_2d=xm_loss_src_2d,
                                       xm_loss_src_3d=xm_loss_src_3d)
            # Out-of-place add: `+=` would modify seg_loss_src_fuse in place,
            # because loss_fuse is the same tensor object.
            loss_fuse = loss_fuse + cfg.TRAIN.XMUDA.lambda_xm_src * (xm_loss_src_2d + xm_loss_src_3d)

        # update metric (e.g. IoU)
        with torch.no_grad():
            train_metric_fuse.update_dict(preds_fuse, data_batch_src)

        # backward
        loss_fuse.backward()

        # ---------------------------------------------------------------------------- #
        # Train on target
        # ---------------------------------------------------------------------------- #
        preds_fuse = model_fuse(data_batch_trg)

        loss_fuse = []
        if cfg.TRAIN.XMUDA.lambda_xm_trg > 0:
            # cross-modal loss: KL divergence
            xm_loss_trg_2d = _cross_modal_kl(preds_fuse['seg_logit_2d'], preds_fuse['seg_logit'])
            xm_loss_trg_3d = _cross_modal_kl(preds_fuse['seg_logit_3d'], preds_fuse['seg_logit'])
            train_metric_logger.update(xm_loss_trg_2d=xm_loss_trg_2d,
                                       xm_loss_trg_3d=xm_loss_trg_3d)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_xm_trg * xm_loss_trg_2d)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_xm_trg * xm_loss_trg_3d)
        if cfg.TRAIN.XMUDA.lambda_pl > 0:
            # self-training loss with pseudo labels
            # Note that the fused labels must be stored in 'pseudo_label_2d'
            pl_loss_trg_fuse = F.cross_entropy(preds_fuse['seg_logit'], data_batch_trg['pseudo_label_2d'],
                                               weight=class_weights_pl)
            train_metric_logger.update(pl_loss_trg_fuse=pl_loss_trg_fuse)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_pl * pl_loss_trg_fuse)
        if cfg.TRAIN.XMUDA.lambda_minent > 0:
            # MinEnt
            minent_loss_trg_fuse = entropy_loss(F.softmax(preds_fuse['seg_logit'], dim=1))
            # NOTE(review): the meter is named '..._2d' although the loss comes
            # from the fused logits; name kept so existing logs stay comparable.
            train_metric_logger.update(minent_loss_trg_2d=minent_loss_trg_fuse)
            loss_fuse.append(cfg.TRAIN.XMUDA.lambda_minent * minent_loss_trg_fuse)

        # With no target loss enabled the list is empty; sum([]) is the int 0,
        # which has no .backward(), so only backpropagate when there is a term.
        if loss_fuse:
            sum(loss_fuse).backward()

        optimizer_fuse.step()

        batch_time = time.time() - end
        train_metric_logger.update(time=batch_time, data=data_time)

        # log
        cur_iter = iteration + 1
        if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):
            logger.info(
                train_metric_logger.delimiter.join(
                    [
                        'iter: {iter:4d}',
                        '{meters}',
                        'lr: {lr:.2e}',
                        'max mem: {memory:.0f}',
                    ]
                ).format(
                    iter=cur_iter,
                    meters=str(train_metric_logger),
                    lr=optimizer_fuse.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                )
            )

        # summary
        if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:
            keywords = ('loss', 'acc', 'iou')
            for name, meter in train_metric_logger.meters.items():
                if all(k not in name for k in keywords):
                    continue
                summary_writer.add_scalar('train/' + name, meter.avg, global_step=cur_iter)

        # checkpoint
        if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:
            checkpoint_data_fuse['iteration'] = cur_iter
            checkpoint_data_fuse[best_metric_name] = best_metric['2d']
            checkpointer_fuse.save('model_fuse_{:06d}'.format(cur_iter), **checkpoint_data_fuse)

        # ---------------------------------------------------------------------------- #
        # validate for one epoch
        # ---------------------------------------------------------------------------- #
        if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):
            start_time_val = time.time()
            setup_validate()

            validate(cfg,
                     model_fuse,
                     None,
                     val_dataloader,
                     val_metric_logger)

            epoch_time_val = time.time() - start_time_val
            logger.info('Iteration[{}]-Val {} total_time: {:.2f}s'.format(
                cur_iter, val_metric_logger.summary_str, epoch_time_val))

            # summary
            if summary_writer is not None:
                keywords = ('loss', 'acc', 'iou')
                for name, meter in val_metric_logger.meters.items():
                    if all(k not in name for k in keywords):
                        continue
                    summary_writer.add_scalar('val/' + name, meter.avg, global_step=cur_iter)

            # best validation
            for modality in ['2d']:
                cur_metric_name = cfg.VAL.METRIC + '_' + modality
                if cur_metric_name in val_metric_logger.meters:
                    cur_metric = val_metric_logger.meters[cur_metric_name].global_avg
                    if best_metric[modality] is None or best_metric[modality] < cur_metric:
                        best_metric[modality] = cur_metric
                        best_metric_iter[modality] = cur_iter

            # restore training
            setup_train()

        scheduler_fuse.step()
        end = time.time()

    for modality in ['2d']:
        if best_metric[modality] is None:
            # No validation result was recorded, so there is no score to
            # scale and format (None * 100 would raise TypeError).
            logger.info('No val-{}-{} result was recorded'.format(modality.upper(), cfg.VAL.METRIC))
            continue
        logger.info('Best val-{}-{} = {:.2f} at iteration {}'.format(modality.upper(),
                                                                   cfg.VAL.METRIC,
                                                                   best_metric[modality] * 100,
                                                                   best_metric_iter[modality]))
def main():
    """Script entry point: load the config, set up output and logging, then train.

    Reads the command line via ``parse_args()``, merges the config file and
    the command-line overrides into the xMUDA config, prepares the output
    directory and logger, validates the fusion-specific settings and calls
    ``train``.

    Raises:
        ValueError: If the 2D and 3D ``DUAL_HEAD`` settings differ, or if none
            of the xMUDA loss weights checked below is positive.
    """
    args = parse_args()

    # load the configuration
    # import on-the-fly to avoid overwriting cfg
    from xmuda.common.config import purge_cfg
    from xmuda.config.xmuda import cfg
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    purge_cfg(cfg)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    # replace '@' with config path
    if output_dir:
        config_path = osp.splitext(args.config_file)[0]
        output_dir = output_dir.replace('@', config_path.replace('configs/', ''))
        if osp.isdir(output_dir):
            warnings.warn('Output directory exists.')
        os.makedirs(output_dir, exist_ok=True)

    # run name
    timestamp = time.strftime('%m-%d_%H-%M-%S')
    hostname = socket.gethostname()
    run_name = '{:s}.{:s}'.format(timestamp, hostname)

    logger = setup_logger('xmuda', output_dir, comment='train.{:s}'.format(run_name))
    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))
    logger.info(args)

    logger.info('Loaded configuration file {:s}'.format(args.config_file))
    logger.info('Running with config:\n{}'.format(cfg))

    # Explicit raises instead of `assert`: assertions are stripped when Python
    # runs with -O, which would silently skip this validation.

    # The 2D and 3D DUAL_HEAD settings must agree.
    # NOTE(review): the original comment says a dual head is *necessary* for
    # the cross-modal loss, but only equality is checked here -- confirm.
    if cfg.MODEL_2D.DUAL_HEAD != cfg.MODEL_3D.DUAL_HEAD:
        raise ValueError('MODEL_2D.DUAL_HEAD and MODEL_3D.DUAL_HEAD must have the same value.')

    # At least one of the xMUDA loss weights must be enabled.
    xmuda_cfg = cfg.TRAIN.XMUDA
    loss_weights = (
        xmuda_cfg.lambda_xm_src,
        xmuda_cfg.lambda_xm_trg,
        xmuda_cfg.lambda_pl,
        xmuda_cfg.lambda_minent,
    )
    if not any(weight > 0 for weight in loss_weights):
        raise ValueError('At least one of TRAIN.XMUDA.lambda_xm_src, lambda_xm_trg, lambda_pl or '
                         'lambda_minent must be > 0.')

    train(cfg, output_dir, run_name)
if __name__ == '__main__':
    # Start training only when the file is executed as a script, not on import.
    main()
Thanks for your reply with the code.
I observed something different from what I expected. Based on the figure below and equations 2 and 3 (main paper), my understanding is that the code and the arrow direction in the figure do not match. Q1. Can you explain this part? Have I missed something?
seg_logit_2d = preds_fuse['seg_logit_2d']
seg_logit_3d = preds_fuse['seg_logit_3d']
# My understanding based on the figure:
# the fused prediction is the log-softmax input of F.kl_div, and the detached
# 2D / 3D predictions are the targets (gradients flow into the fused logits).
xm_loss_src_2d = F.kl_div(F.log_softmax(preds_fuse['seg_logit'], dim=1),
F.softmax(seg_logit_2d.detach(), dim=1),
reduction='none').sum(1).mean()
xm_loss_src_3d = F.kl_div(F.log_softmax(preds_fuse['seg_logit'], dim=1),
F.softmax(seg_logit_3d.detach(), dim=1),
reduction='none').sum(1).mean()
# Your provided code:
# the 2D / 3D predictions are the log-softmax inputs of F.kl_div, and the
# detached fused prediction is the target (gradients flow into the 2D / 3D logits).
xm_loss_src_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),
F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
reduction='none').sum(1).mean()
xm_loss_src_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),
F.softmax(preds_fuse['seg_logit'].detach(), dim=1),
reduction='none').sum(1).mean()
Hi, The D_KL annotation might be confusing, as it is defined as D_KL(P||Q), where P is the target. It is worth having a look at the definition of the Kullback-Leibler divergence: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
The following explains it:
Consider two probability distributions P and Q. Usually, P represents the data, the observations, or a measured probability distribution. Distribution Q represents instead a theory, a model, a description or an approximation of P.
This means that P is the target, in our case P_fuse. The arrow direction models the flow of information, i.e. P_fuse teaches P_3D->fuse and P_2D->fuse.
Hope that helps.
Thanks, I clearly understood KL loss.
Thank you for sharing the code. It is very great! Do you have any plan to release the code of Vanilla Fusion or xMUDA Fusion?