Closed josh3255 closed 1 year ago
Hi, usually, errors like yours happen when people use OpenMMLab 2.0 and 1.0 codes together. You may have to follow this document to update your MMClassification.
@xin-li-67 thank you :) it's work well.
but change backbone According to this guideline, the delete parameter is used when using mmcls.convnext, but follow it the same way, get this error
TypeError: class `YOLODetector` in mmyolo/models/detectors/yolo_detector.py: class `ConvNeXt` in mmcls/models/backbones/convnext.py: __init__() got an unexpected keyword argument '_delete_'
Hi, can you provide a screenshot of your modified config file?? This may be easier to debug~
@xin-li-67 Thank you for your help!
_base_ = '../../_base_/default_runtime.py'
# dataset settings
data_root = '/mmyolo/data/detection/detection/'
dataset_type = 'YOLOv5CocoDataset'
# parameters that often need to be modified
img_scale = (640, 640) # width, height
max_epochs = 300
save_epoch_intervals = 5
train_batch_size_per_gpu = 4
train_num_workers = 8
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
val_batch_size_per_gpu = 1
val_num_workers = 2
deepen_factor = 0.67
widen_factor = 0.75
# only on Val
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# different from yolov5
anchors = [
[(12, 16), (19, 36), (40, 28)], # P3/8
[(36, 75), (76, 55), (72, 146)], # P4/16
[(142, 110), (192, 243), (459, 401)] # P5/32
]
strides = [8, 16, 32]
num_det_layers = 3
num_classes = 3
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth' # noqa
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
_delete_=True,
type='mmcls.ConvNeXt',
arch='small',
out_indices=(0, 1, 2, 3),
drop_path_rate=0.6,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained',
checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(
type='YOLOv7PAFPN',
block_cfg=dict(
type='ELANBlock',
middle_ratio=0.5,
block_ratio=0.25,
num_blocks=4,
num_convs_in_block=1),
upsample_feats_cat_first=False,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[512, 1024, 1024],
# The real output channel will be multiplied by 2
out_channels=[128, 256, 512],
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOv7Head',
head_module=dict(
type='YOLOv7HeadModule',
num_classes=num_classes,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
featmap_strides=strides,
num_base_priors=3),
prior_generator=dict(
type='mmdet.YOLOAnchorGenerator',
base_sizes=anchors,
strides=strides),
# scaled based on number of detection layers
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=0.3 * (num_classes / 80 * 3 / num_det_layers)),
loss_bbox=dict(
type='IoULoss',
iou_mode='ciou',
bbox_format='xywh',
reduction='mean',
loss_weight=0.05 * (3 / num_det_layers),
return_iou=True),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=0.7 * ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
obj_level_weights=[4., 1., 0.4],
# BatchYOLOv7Assigner params
prior_match_thr=4.,
simota_candidate_topk=10,
simota_iou_weight=3.0,
simota_cls_weight=1.0),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300))
pre_transform = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadAnnotations', with_bbox=True)
]
mosiac4_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.2, # note
scaling_ratio_range=(0.1, 2.0), # note
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
mosiac9_pipeline = [
dict(
type='Mosaic9',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.2, # note
scaling_ratio_range=(0.1, 2.0), # note
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
randchoice_mosaic_pipeline = dict(
type='RandomChoice',
transforms=[mosiac4_pipeline, mosiac9_pipeline],
prob=[0.8, 0.2])
train_pipeline = [
*pre_transform,
randchoice_mosaic_pipeline,
dict(
type='YOLOv5MixUp',
alpha=8.0, # note
beta=8.0, # note
prob=0.15,
pre_transform=[*pre_transform, randchoice_mosaic_pipeline]),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='yolov5_collate'), # FASTER
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/train.json',
data_prefix=dict(img='images/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img='images/'),
ann_file='annotations/val.json',
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
param_scheduler = None
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=0.01,
momentum=0.937,
weight_decay=0.0005,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv7OptimWrapperConstructor')
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=0.1, # note
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
save_param_scheduler=False,
interval=1,
save_best='auto',
max_keep_ckpts=3))
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10), # Can be accelerated
ann_file=data_root + 'annotations/val.json',
metric='bbox')
test_evaluator = val_evaluator
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_epoch_intervals,
dynamic_intervals=[(270, 1)])
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49)
]
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# randomness = dict(seed=1, deterministic=True)
If remove the delete parameter, it works, but get the following error.
Traceback (most recent call last):
File "tools/train.py", line 116, in <module>
main()
File "tools/train.py", line 112, in main
runner.train()
File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/runner.py", line 1706, in train
model = self.train_loop.run() # type: ignore
File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/loops.py", line 96, in run
self.run_epoch()
File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/loops.py", line 112, in run_epoch
self.run_iter(idx, data_batch)
File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/loops.py", line 129, in run_iter
data_batch, optim_wrapper=self.runner.optim_wrapper)
File "/opt/conda/lib/python3.7/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
losses = self._run_forward(data, mode='loss') # type: ignore
File "/opt/conda/lib/python3.7/site-packages/mmengine/model/base_model/base_model.py", line 326, in _run_forward
results = self(**data, mode=mode)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 92, in forward
return self.loss(inputs, data_samples)
File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/single_stage.py", line 77, in loss
x = self.extract_feat(batch_inputs)
File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/single_stage.py", line 148, in extract_feat
x = self.neck(x)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mmyolo/mmyolo/models/necks/base_yolo_neck.py", line 224, in forward
assert len(inputs) == len(self.in_channels)
@xin-li-67 Thank you for your help!
_base_ = '../../_base_/default_runtime.py' # dataset settings data_root = '/mmyolo/data/detection/detection/' dataset_type = 'YOLOv5CocoDataset' # parameters that often need to be modified img_scale = (640, 640) # width, height max_epochs = 300 save_epoch_intervals = 5 train_batch_size_per_gpu = 4 train_num_workers = 8 # persistent_workers must be False if num_workers is 0. persistent_workers = True val_batch_size_per_gpu = 1 val_num_workers = 2 deepen_factor = 0.67 widen_factor = 0.75 # only on Val batch_shapes_cfg = dict( type='BatchShapePolicy', batch_size=val_batch_size_per_gpu, img_size=img_scale[0], size_divisor=32, extra_pad_ratio=0.5) # different from yolov5 anchors = [ [(12, 16), (19, 36), (40, 28)], # P3/8 [(36, 75), (76, 55), (72, 146)], # P4/16 [(142, 110), (192, 243), (459, 401)] # P5/32 ] strides = [8, 16, 32] num_det_layers = 3 num_classes = 3 custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth' # noqa # single-scale training is recommended to # be turned on, which can speed up training. env_cfg = dict(cudnn_benchmark=True) model = dict( type='YOLODetector', data_preprocessor=dict( type='YOLOv5DetDataPreprocessor', mean=[0., 0., 0.], std=[255., 255., 255.], bgr_to_rgb=True), backbone=dict( _delete_=True, type='mmcls.ConvNeXt', arch='small', out_indices=(0, 1, 2, 3), drop_path_rate=0.6, layer_scale_init_value=1.0, gap_before_final_norm=False, init_cfg=dict( type='Pretrained', checkpoint=checkpoint_file, prefix='backbone.')), neck=dict( type='YOLOv7PAFPN', block_cfg=dict( type='ELANBlock', middle_ratio=0.5, block_ratio=0.25, num_blocks=4, num_convs_in_block=1), upsample_feats_cat_first=False, deepen_factor=deepen_factor, widen_factor=widen_factor, in_channels=[512, 1024, 1024], # The real output channel will be multiplied by 2 out_channels=[128, 256, 512], norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), act_cfg=dict(type='SiLU', inplace=True)), bbox_head=dict( type='YOLOv7Head', head_module=dict( type='YOLOv7HeadModule', num_classes=num_classes, widen_factor=widen_factor, in_channels=[256, 512, 1024], featmap_strides=strides, num_base_priors=3), prior_generator=dict( type='mmdet.YOLOAnchorGenerator', base_sizes=anchors, strides=strides), # scaled based on number of detection layers loss_cls=dict( type='mmdet.CrossEntropyLoss', use_sigmoid=True, reduction='mean', loss_weight=0.3 * (num_classes / 80 * 3 / num_det_layers)), loss_bbox=dict( type='IoULoss', iou_mode='ciou', bbox_format='xywh', reduction='mean', loss_weight=0.05 * (3 / num_det_layers), return_iou=True), loss_obj=dict( type='mmdet.CrossEntropyLoss', use_sigmoid=True, reduction='mean', loss_weight=0.7 * ((img_scale[0] / 640)**2 * 3 / num_det_layers)), obj_level_weights=[4., 1., 0.4], # BatchYOLOv7Assigner params prior_match_thr=4., simota_candidate_topk=10, simota_iou_weight=3.0, simota_cls_weight=1.0), test_cfg=dict( multi_label=True, nms_pre=30000, score_thr=0.001, nms=dict(type='nms', iou_threshold=0.65), max_per_img=300)) pre_transform = [ dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), dict(type='LoadAnnotations', with_bbox=True) ] mosiac4_pipeline = [ dict( type='Mosaic', img_scale=img_scale, pad_val=114.0, pre_transform=pre_transform), dict( type='YOLOv5RandomAffine', max_rotate_degree=0.0, max_shear_degree=0.0, max_translate_ratio=0.2, # note scaling_ratio_range=(0.1, 2.0), # note # img_scale is (width, height) border=(-img_scale[0] // 2, -img_scale[1] // 2), border_val=(114, 114, 114)), ] mosiac9_pipeline = [ dict( type='Mosaic9', img_scale=img_scale, pad_val=114.0, pre_transform=pre_transform), dict( type='YOLOv5RandomAffine', max_rotate_degree=0.0, max_shear_degree=0.0, max_translate_ratio=0.2, # note scaling_ratio_range=(0.1, 2.0), # note # img_scale is (width, height) border=(-img_scale[0] // 2, -img_scale[1] // 2), border_val=(114, 114, 114)), ] randchoice_mosaic_pipeline = dict( type='RandomChoice', transforms=[mosiac4_pipeline, mosiac9_pipeline], prob=[0.8, 0.2]) train_pipeline = [ *pre_transform, randchoice_mosaic_pipeline, dict( type='YOLOv5MixUp', alpha=8.0, # note beta=8.0, # note prob=0.15, pre_transform=[*pre_transform, randchoice_mosaic_pipeline]), dict(type='YOLOv5HSVRandomAug'), dict(type='mmdet.RandomFlip', prob=0.5), dict( type='mmdet.PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip', 'flip_direction')) ] train_dataloader = dict( batch_size=train_batch_size_per_gpu, num_workers=train_num_workers, persistent_workers=persistent_workers, pin_memory=True, sampler=dict(type='DefaultSampler', shuffle=True), collate_fn=dict(type='yolov5_collate'), # FASTER dataset=dict( type=dataset_type, data_root=data_root, ann_file='annotations/train.json', data_prefix=dict(img='images/'), filter_cfg=dict(filter_empty_gt=False, min_size=32), pipeline=train_pipeline)) test_pipeline = [ dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), dict(type='YOLOv5KeepRatioResize', scale=img_scale), dict( type='LetterResize', scale=img_scale, allow_scale_up=False, pad_val=dict(img=114)), dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'), dict( type='mmdet.PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'pad_param')) ] val_dataloader = dict( batch_size=val_batch_size_per_gpu, num_workers=val_num_workers, persistent_workers=persistent_workers, pin_memory=True, drop_last=False, sampler=dict(type='DefaultSampler', shuffle=False), dataset=dict( type=dataset_type, data_root=data_root, test_mode=True, data_prefix=dict(img='images/'), ann_file='annotations/val.json', pipeline=test_pipeline, batch_shapes_cfg=batch_shapes_cfg)) test_dataloader = val_dataloader param_scheduler = None optim_wrapper = dict( type='OptimWrapper', optimizer=dict( type='SGD', lr=0.01, momentum=0.937, weight_decay=0.0005, nesterov=True, batch_size_per_gpu=train_batch_size_per_gpu), constructor='YOLOv7OptimWrapperConstructor') default_hooks = dict( param_scheduler=dict( type='YOLOv5ParamSchedulerHook', scheduler_type='cosine', lr_factor=0.1, # note max_epochs=max_epochs), checkpoint=dict( type='CheckpointHook', save_param_scheduler=False, interval=1, save_best='auto', max_keep_ckpts=3)) val_evaluator = dict( type='mmdet.CocoMetric', proposal_nums=(100, 1, 10), # Can be accelerated ann_file=data_root + 'annotations/val.json', metric='bbox') test_evaluator = val_evaluator train_cfg = dict( type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=save_epoch_intervals, dynamic_intervals=[(270, 1)]) custom_hooks = [ dict( type='EMAHook', ema_type='ExpMomentumEMA', momentum=0.0001, update_buffers=True, strict_load=False, priority=49) ] val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') # randomness = dict(seed=1, deterministic=True)
If remove the delete parameter, it works, but get the following error.
Traceback (most recent call last): File "tools/train.py", line 116, in <module> main() File "tools/train.py", line 112, in main runner.train() File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/runner.py", line 1706, in train model = self.train_loop.run() # type: ignore File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/loops.py", line 96, in run self.run_epoch() File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/loops.py", line 112, in run_epoch self.run_iter(idx, data_batch) File "/opt/conda/lib/python3.7/site-packages/mmengine/runner/loops.py", line 129, in run_iter data_batch, optim_wrapper=self.runner.optim_wrapper) File "/opt/conda/lib/python3.7/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step losses = self._run_forward(data, mode='loss') # type: ignore File "/opt/conda/lib/python3.7/site-packages/mmengine/model/base_model/base_model.py", line 326, in _run_forward results = self(**data, mode=mode) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl return forward_call(*input, **kwargs) File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 92, in forward return self.loss(inputs, data_samples) File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/single_stage.py", line 77, in loss x = self.extract_feat(batch_inputs) File "/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/single_stage.py", line 148, in extract_feat x = self.neck(x) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl return forward_call(*input, **kwargs) File "/mmyolo/mmyolo/models/necks/base_yolo_neck.py", line 224, in forward assert len(inputs) == len(self.in_channels)
Hi josh, as you saw in the replace backbone document, the _delete_
is set to replace the original design the _base_
file. However, you used the base_runtime.py
as the _base_
file in your yolov7_convnext.py
which has no implementation on the model. I believe this is the cause of your issue. If you replace the
base` with YOLOv7's file and remove the redundant contents, it will probably not come up.
For the other error, I think it may be related to the output dims are not equal to your neck input dimension.
@xin-li-67 Thank you. Problem solved :)
Prerequisite
🐞 Describe the bug
When I installed mmcls and modified the config file of my yolov7 and changed the backbone to mmcls.convnext, the following error occurs.
yolov7_convnext.py
mmcls installation :
mmyolo installation :
init.py of mmcls
init.py of mmyolo
Environment
Additional information
No response