open-mmlab / mmagic

OpenMMLab Multimodal Advanced, Generative, and Intelligent Creation Toolbox. Unlock the magic 🪄: Generative-AI (AIGC), easy-to-use APIs, awesome model zoo, diffusion models, for text-to-image generation, image/video restoration/enhancement, etc.
https://mmagic.readthedocs.io/en/latest/
Apache License 2.0
6.89k stars 1.06k forks

[Bug] Can not use `ConcatImageVisualizer` in train step. / Visualization in the train stage has a bug #2118

Open Pengyang233 opened 8 months ago

Pengyang233 commented 8 months ago

Prerequisite

Task

I have modified the scripts/configs, or I'm working on my own tasks/models/datasets.

Branch

main branch https://github.com/open-mmlab/mmagic

Environment

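My English is poor, so I originally wrote this in Chinese.

`_after_iter` in `BasicVisualizationHook` is called at the `after_iter` stage of Train, Val, and Test to perform visualization. It visualizes the `outputs` it receives, and here `outputs` must be of `DataSample` type.

With `BaseEditModel`, however, the `outputs` coming from the train loops (`IterBasedTrainLoop`, `EpochBasedTrainLoop`) is a dict that contains only the loss, not a `DataSample`. It therefore cannot be parsed by `_after_iter` in `BasicVisualizationHook`, and an error is raised.

Below is my modification. It is a bit crude. The change is in `mmagic/engine/hooks/visualization_hook.py`, in the `BasicVisualizationHook._after_iter` function.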

        if mode == 'train':
            # 2024-01-26 bug fix: in the train stage, the `outputs` passed in
            # is a dict that only contains 'loss'; it is not of DataSample type.
            if self.every_n_inner_iters(batch_idx, interval):
                # Saving only the first sample of each batch should be enough?
                data_sample_0 = data_batch['data_samples'][0]
                data_sample_0.set_tensor_data({'input': data_batch['inputs'][0]})
                runner.visualizer.add_datasample(data_sample_0, step=runner.iter)
        else:
            if self.every_n_inner_iters(batch_idx, interval):
                for data_sample in outputs:
                    runner.visualizer.add_datasample(data_sample, step=runner.iter)
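A variant of the workaround above could branch on the type of `outputs` instead of on `mode`, so that any loop returning a loss dict is tolerated. The following is only a minimal, self-contained sketch of that selection logic, not MMagic code: `pick_samples` and `StubSample` are hypothetical names, and the stub merely imitates the `set_tensor_data` call used in the snippet above.

```python
class StubSample:
    """Hypothetical stand-in for a data sample (not MMagic's class)."""

    def __init__(self, name):
        self.name = name
        self.tensors = {}

    def set_tensor_data(self, data):
        # Assumption: attaches tensors by key, like the call in the workaround.
        self.tensors.update(data)


def pick_samples(data_batch, outputs):
    """Return the samples a visualization hook could hand to add_datasample."""
    if isinstance(outputs, dict):
        # Train stage as described in the report: outputs only holds losses,
        # so fall back to the first sample of the batch and attach its input.
        sample = data_batch['data_samples'][0]
        sample.set_tensor_data({'input': data_batch['inputs'][0]})
        return [sample]
    # Val/test stage: outputs is already a sequence of samples.
    return list(outputs)


# Usage with placeholder strings standing in for image tensors.
batch = {
    'inputs': ['img0', 'img1'],
    'data_samples': [StubSample('a'), StubSample('b')],
}
train_picked = pick_samples(batch, {'loss': 0.5})
val_picked = pick_samples(batch, [StubSample('c'), StubSample('d')])
```

In a real hook, each returned sample would then be passed to `runner.visualizer.add_datasample(sample, step=runner.iter)` as in the snippet above.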

Reproduces the problem - code sample

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    scope='mmagic',
    type='ConcatImageVisualizer',
    vis_backends=vis_backends,
    pixel_range={'img_vis': [0, 1], 'gt_img': [0, 1]},
    fn_key='gt_path',
    img_keys=['gt_img', 'img_vis'],
    bgr2rgb=False,
)
custom_hooks = [
    dict(
        scope='mmagic',
        type='BasicVisualizationHook',
        interval=1,
        on_train=True,
        on_val=False)
]

Reproduces the problem - command or script

CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29999 tools/dist_train.sh \
    configs/gaussSharpen_noAdjust.py 4

Reproduces the problem - error message

Traceback (most recent call last):
  File "/mmagic/mmagic/visualization/concat_visualizer.py", line 70, in add_datasample
    **data_sample.to_dict(),
AttributeError: 'str' object has no attribute 'to_dict'
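The `'str' object` in this message is consistent with the description above: if the train-stage `outputs` is a dict of losses, iterating over it yields its keys, which are strings. The sketch below reproduces that behaviour in plain Python without MMagic; `FakeDataSample` and `collect_dicts` are hypothetical stand-ins, and the loss value is made up.

```python
# Stand-in for what the hook receives during training, per the report:
# a dict of losses rather than a list of data samples.
train_outputs = {'loss': 0.123}


class FakeDataSample:
    """Hypothetical stand-in for a data sample; only to_dict() matters here."""

    def to_dict(self):
        return {'gt_img': None}


def collect_dicts(outputs):
    # Mirrors the hook's loop: iterate over outputs, call to_dict() on each item.
    return [item.to_dict() for item in outputs]


# Val/test style outputs: a list of samples works as expected.
val_result = collect_dicts([FakeDataSample()])

# Train style outputs: iterating a dict yields its string keys, so the
# to_dict() call is attempted on the string 'loss' and fails.
try:
    collect_dicts(train_outputs)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```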

Additional information

No response