facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

CommonMetricPrinter should respect smoothing_hint of losses #3413

Open Poulinakis-Konstantinos opened 3 years ago

Poulinakis-Konstantinos commented 3 years ago

I am observing some unexpected behavior which might be caused by a bug. I am trying to train a RetinaNet with a custom training loop based on plain_train_net.py.

Inside the do_train function I use CommonMetricPrinter and JSONWriter to log my training metrics. To provide metrics to the writers I use storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced, smoothing_hint=False). Afterwards I also print(losses_reduced) as a check.

The values that the print call outputs are, most of the time, different from the values logged by the writers.

I am wondering if I am missing something or if this is indeed a bug. I also disabled smoothing_hint in case it was the root of the problem.
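
To make the comparison concrete, here is a minimal sketch of the pattern, detached from my training loop (the loss values below are made up): it logs a scalar with smoothing_hint=False and prints both the raw latest() value and the 20-iteration median, which as far as I can tell is the window the smoothed output is based on.

    from detectron2.utils.events import EventStorage

    with EventStorage(0) as storage:
        for fake_loss in [1.0, 0.5, 2.0, 0.1]:  # made-up loss values
            storage.put_scalars(total_loss=fake_loss, smoothing_hint=False)
            buf = storage.history("total_loss")
            # raw value that print() would show vs. the median over the last 20 values
            print(storage.iter, buf.latest(), buf.median(20))
            storage.step()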

Instructions To Reproduce the šŸ› Bug:

  1. Full runnable code or full changes you made:
    
    from tqdm import tqdm 
    import logging
    import os
    from collections import OrderedDict
    import torch
    from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.checkpoint import Checkpointer, DetectionCheckpointer, PeriodicCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (MetadataCatalog, build_detection_test_loader, build_detection_train_loader, DatasetMapper)
from detectron2.engine import default_argument_parser, default_setup, launch
from detectron2.evaluation import (COCOEvaluator, DatasetEvaluators, inference_on_dataset, print_csv_format)
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils.events import (CommonMetricPrinter, EventStorage, JSONWriter, TensorboardXWriter)
from detectron2.data import MetadataCatalog
from detectron2.data.catalog import DatasetCatalog
from detectron2.data.datasets import register_coco_instances

logger = logging.getLogger("detectron2")
from detectron2 import model_zoo
from tensorboard import program

def get_COCO_evaluator(dataset_name, output_folder=None):
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    return COCOEvaluator(dataset_name, tasks=('bbox',), distributed=True, output_dir=output_folder)

from custom_callbacks import callback_best_weights_mAP

# adding a new parameter current_iteration

def do_test(cfg, model, current_iteration):
    results = OrderedDict()

# Create a JSON writer for logging evaluation results

if comm.is_main_process():
    writers= [
        CommonMetricPrinter(cfg.SOLVER.MAX_ITER),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "evaluation_metrics.json"))
    ]
else: 
     writers = []

# Initiate EventStorage 
with EventStorage(current_iteration) as storage:
    for dataset_name in cfg.DATASETS.TEST:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        evaluator = get_COCO_evaluator(
            dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        )

        results_i = inference_on_dataset(model, data_loader, evaluator)
        results[dataset_name] = results_i
        # Log the metrics computed by COCOEvaluator in the jwriter
        storage.put_scalar("AP-Bad", results_i['bbox']['AP-Bad'])
        storage.put_scalar("AP-Good", results_i['bbox']['AP-Good'])
        storage.put_scalar("AP75", results_i['bbox']['AP75'])
        storage.put_scalar("AP50", results_i['bbox']['AP50'])
        storage.put_scalar("mAP", results_i['bbox']['AP'])

        if comm.is_main_process():
            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
            print_csv_format(results_i)
            # writes what is currently in storage.latest_with_smoothing_hint. Use storage.put_scalars to add values
            for writer in writers:
                writer.write()

        # Custom callback function call        
        callback_best_weights_mAP(cfg, model, current_value = results_i['bbox']['AP'], model_name = MODEL_NAME )
        print("History mAP: ", storage.history('mAP'))
        print("Results[dataset_name] : ",results[dataset_name])

    if len(results) == 1:
        results = list(results.values())[0]
return results

def do_train(cfg, model, resume=False):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

checkpointer = DetectionCheckpointer(
    model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
)
start_iter = (
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
)
max_iter = cfg.SOLVER.MAX_ITER

periodic_checkpointer = PeriodicCheckpointer(
    checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
)

writers = (
    [
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ]
    if comm.is_main_process()
    else []
)
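# CommonMetricPrinter prints (smoothed) metrics to the console, JSONWriter appends to metrics.json, TensorboardXWriter writes TensorBoard event files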

data_loader = build_detection_train_loader(cfg) # loading train data
for dataset_name in cfg.DATASETS.TEST:
    data_loader_val = build_detection_test_loader(cfg, dataset_name,DatasetMapper(cfg,True)) # loading validation data
logger.info("Starting training from iteration {}".format(start_iter))

with EventStorage(start_iter) as storage:
    for data, iteration in zip(data_loader, range(start_iter, max_iter)):
        iteration = iteration + 1
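        # storage.step() advances storage.iter, the iteration that writers associate with newly logged scalars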
        storage.step()
        loss_dict = model(data)
        losses = sum(loss for loss in loss_dict.values())
        assert torch.isfinite(losses).all(), loss_dict

        # comm.reduce_dict averages the loss dict across all workers (the reduced values are gathered on the main process)
        loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        if comm.is_main_process():

            # THIS IS WHERE THE BUG MIGHT BE
            storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced, smoothing_hint=False)
            print("TRAIN LOSSES :", losses_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        #storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
        scheduler.step()

        if (
            cfg.TEST.EVAL_PERIOD > 0
            and iteration % cfg.TEST.EVAL_PERIOD == 0
            and iteration != (max_iter - 1)
        ):
            if iteration == max_iter :
                do_test(cfg, model, max_iter)
                comm.synchronize()
            else :
                do_test(cfg, model, iteration)
                comm.synchronize()

        for writer in writers:
            writer.write()
        periodic_checkpointer.step(iteration)

        # The model computes losses when in training mode, so model(validation_inputs) returns a loss dict.
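        # Note: no gradients are needed for this validation pass, so the loop below could be wrapped in torch.no_grad().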
        total = 0
        for idx, inputs in tqdm(enumerate(data_loader_val)):  
            val_loss_dict = model(inputs)
            # Reduce the dictionary to only the necessary values (drop the other info)
            val_loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(val_loss_dict).items()}
           #val_total_loss = sum( val_loss for val_loss in val_loss_dict.values() ) 
            val_total_loss_reduced = sum(val_loss for val_loss in val_loss_dict_reduced.values() ) 
            total = total  + val_total_loss_reduced

        val_total_avg_loss = total/len(data_loader_val)
        print("VALIDATION TOTAL LOSS IS : ", val_total_avg_loss)

def setup(args):
    """
    Create configs and perform basic setups.
    """
    OUTPUT_DIR = 'Output_V_1.0'

# Register Datasets ( probably already registered datasets don't need to be reregistered)
register_coco_instances("my_dataset_train", {}, "./cvat_obj_det_apricots/train.json", "./cvat_obj_det_apricots")
register_coco_instances("my_dataset_val", {}, "./cvat_obj_det_apricots/val.json", "./cvat_obj_det_apricots")
register_coco_instances("my_dataset_test", {}, "./cvat_obj_det_apricots/test.json", "./cvat_obj_det_apricots")

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/retinanet_R_101_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ('my_dataset_train',)
cfg.DATASETS.TEST = ("my_dataset_val",)

cfg.DATALOADER.NUM_WORKERS = 4
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/retinanet_R_101_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.RETINANET.NUM_CLASSES = 2 # <---  DO NOT FORGET TO SET NUMBER OF CLASSES! (SAME TO BE INSERTED IN MAIN.PY)

cfg.SOLVER.MAX_ITER = 100 # Adjust up if val mAP is still rising, adjust down if overfit
cfg.SOLVER.MOMENTUM = 0.9
cfg.SOLVER.WARMUP_ITERS = 50
cfg.SOLVER.IMS_PER_BATCH = 6
cfg.SOLVER.LR_POLICY = 'step' # Test why Policy is not being applied and lr is constant.

cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.GAMMA = 0.1
cfg.TEST.EVAL_PERIOD = 50
cfg.OUTPUT_DIR = OUTPUT_DIR

cfg.merge_from_file(args.config_file)

cfg.merge_from_list(args.opts)

cfg.freeze()
default_setup(
   cfg, args
)  # if you don't like any of the default setup, write your own setup code

return cfg

def main(args):
    global MODEL_NAME
    MODEL_NAME = 'RetinaNet_V_1.0'

cfg = setup(args)
model = build_model(cfg)
logger.info("Model:\n{}".format(model))
if args.eval_only:
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume
    )
    return do_test(cfg, model)

distributed = comm.get_world_size() > 1
if distributed:
    model = DistributedDataParallel(
        model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
    )

do_train(cfg, model)
return do_test(cfg, model, cfg.SOLVER.MAX_ITER)

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

2. What exact command you run:
 I just run the script from the command line: python3 RetinaNet_Train_Loop

3. __Full logs__ or other relevant observations:

The command line
![image](https://user-images.githubusercontent.com/75034778/131112485-90539ae0-22fb-49c5-9726-61c7eb02a5e2.png)

The metrics.json log 
![image](https://user-images.githubusercontent.com/75034778/131112644-f2239329-b3d4-468d-a276-1faced576c15.png)

4. please simplify the steps as much as possible so they do not require additional resources to
   run, such as a private dataset.

## Expected behavior:
I would expect the print(losses_reduced) call and the writers to log the same value.

## Environment:
sys.platform: linux
Python: 3.8.5 packaged by conda-forge (default, Sep 24 2020, 16:55:52) [GCC 7.5.0]
numpy: 1.20.3
detectron2: 0.5 @/home/air/anaconda3/envs/cvat_detectron2_new/lib/python3.8/site-packages/detectron2
Compiler: GCC 7.3
CUDA compiler: CUDA 11.1
detectron2 arch flags: 3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6
DETECTRON2_ENV_MODULE:
PyTorch: 1.9.0 @/home/air/anaconda3/envs/cvat_detectron2_new/lib/python3.8/site-packages/torch
PyTorch debug build: False
GPU available: Yes
GPU 0: GeForce RTX 2080 Ti (arch=7.5)
CUDA_HOME: /usr
Pillow: 8.3.1
torchvision: 0.10.0 @/home/air/anaconda3/envs/cvat_detectron2_new/lib/python3.8/site-packages/torchvision
torchvision arch flags: 3.5, 5.0, 6.0, 7.0, 7.5, 8.0, 8.6
fvcore: 0.1.5.post20210812
iopath: 0.1.8
cv2: Not found

PyTorch built with:

Poulinakis-Konstantinos commented 3 years ago

I have also observed this behavior with another metric I tried to log.

Bear in mind that for the first iteration the output is always the same for the print and the writers. After the first step, the differences start building up.

My first guess would be that smoothing is actually not disabled, even though I set smoothing_hint=False in storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced, smoothing_hint=False).
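
If it helps, one quick way to check that guess inside do_train is to print the hints the storage actually recorded, right after the call (a sketch; smoothing_hints() just returns the hint stored for each scalar name):

    if comm.is_main_process():
        storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced, smoothing_hint=False)
        # every key logged above should map to False if the hint was stored as intended
        print(storage.smoothing_hints())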

ppwwyyxx commented 3 years ago

Thanks for reporting. I found that CommonMetricPrinter always smooths the losses regardless of smoothing_hint.

This behavior was OK since losses should almost always be smoothed. But I think letting it respect smoothing_hint would be less confusing.
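
In the meantime, a possible workaround (just a sketch, not something shipped with detectron2) is a tiny custom writer that reports the raw, unsmoothed latest value of each loss; it can be appended to the writers list in do_train:

    from detectron2.utils.events import EventWriter, get_event_storage

    class RawLossPrinter(EventWriter):
        """Prints the unsmoothed latest value of every scalar whose name contains 'loss'."""

        def write(self):
            storage = get_event_storage()
            raw = {k: v.latest() for k, v in storage.histories().items() if "loss" in k}
            print("iter", storage.iter, raw)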