Chris-hughes10 / Yolov7-training

A clean, modular implementation of the Yolov7 model family, which uses the official pretrained weights, with utilities for training the model on custom (non-COCO) tasks.
GNU General Public License v3.0

Error when training a _w6 model #20

Open d-colwell opened 10 months ago

d-colwell commented 10 months ago

Hi, loading a w6 config model with the code below:

    model = create_yolov7_model(
        architecture="yolov7-w6", num_classes=num_classes, pretrained=pretrained
    )

raises an error in the constructor of the Yolov7DetectionHeadWithAux class. The cause is that use_implicit_modules is passed to this class's constructor, which does not accept that argument. You can work around this by removing the True from the config on line 470 of model_configs.py, but that exposes a second error: when use_implicit_modules is left at its default, the forward pass of the network throws an index-out-of-bounds error. I have patched this by defaulting the parameter to True when it is called from the Yolov7DetectionHeadWithAux constructor, but I'm not sure whether this is the correct approach.
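Both failures can be reproduced without PyTorch using a minimal sketch. The classes below (ToyHead, BuggyToyHeadWithAux) are hypothetical stand-ins, not the repository's real classes, and the parent's parameter order (num_classes, anchor_sizes_per_layer, strides, use_implicit_modules, in_channels_per_layer) is an assumption inferred from the fix posted later in this thread. One plausible reading of the second error: because the subclass forwards its arguments positionally, in_channels_per_layer lands in the parent's use_implicit_modules slot, so no per-layer output modules are built and the forward pass indexes into an empty list.

```python
class ToyHead:
    """Hypothetical stand-in for Yolov7DetectionHead (parameter order is assumed)."""

    def __init__(self, num_classes=80, anchor_sizes_per_layer=(), strides=(),
                 use_implicit_modules=True, in_channels_per_layer=()):
        self.num_classes = num_classes
        self.use_implicit_modules = use_implicit_modules
        # One output "layer" per input channel count, like a ModuleList of 1x1 convs.
        self.layers = [f"conv({c})" for c in in_channels_per_layer]

    def forward(self, num_feature_maps):
        # Look up one layer per feature map; fails if no layers were built.
        return [self.layers[i] for i in range(num_feature_maps)]


class BuggyToyHeadWithAux(ToyHead):
    """Mirrors the original subclass: no use_implicit_modules, positional forwarding."""

    def __init__(self, num_classes=80, anchor_sizes_per_layer=(), strides=(),
                 in_channels_per_layer=()):
        # The 4th positional argument fills the parent's use_implicit_modules slot.
        super().__init__(num_classes, anchor_sizes_per_layer, strides, in_channels_per_layer)


# Failure 1: the config passes use_implicit_modules, which the subclass rejects.
try:
    BuggyToyHeadWithAux(80, (), (), use_implicit_modules=True,
                        in_channels_per_layer=(256, 512))
except TypeError as exc:
    print("TypeError:", exc)

# Failure 2: without it, the channel list is swallowed and no layers are built.
head = BuggyToyHeadWithAux(80, (), (), in_channels_per_layer=(256, 512))
print(head.use_implicit_modules, head.layers)  # (256, 512) []
try:
    head.forward(2)
except IndexError as exc:
    print("IndexError:", exc)
```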

parlaynu commented 3 months ago

I just ran into the same problem and fixed it by adding the use_implicit_modules argument to the constructor of Yolov7DetectionHeadWithAux, so that it matches Yolov7DetectionHead, and passing it on in the call to super().__init__, as in the diff below. There is no need to change the configs.

     class Yolov7DetectionHeadWithAux(Yolov7DetectionHead):
         def __init__(
             self,
             num_classes=80,
             anchor_sizes_per_layer=(),
             strides: torch.Tensor = (),
    +        use_implicit_modules: bool = True,
             in_channels_per_layer=(),
         ):
             super().__init__(
    -            num_classes, anchor_sizes_per_layer, strides, in_channels_per_layer
    +            num_classes, anchor_sizes_per_layer, strides, use_implicit_modules, in_channels_per_layer
             )
             self.m2 = nn.ModuleList(
                 nn.Conv2d(in_channels, self.num_outputs * self.num_anchor_sizes, 1)
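The diff forwards use_implicit_modules positionally, which is correct as long as the argument order matches Yolov7DetectionHead. A slightly more defensive variant passes every argument by keyword, so a future reordering of the parent's parameters cannot silently shift values into the wrong slot. This is a sketch with hypothetical stand-in classes (ToyHead, FixedToyHeadWithAux), not the repository's real code; the parent's parameter order is assumed from the diff.

```python
class ToyHead:
    """Hypothetical stand-in for Yolov7DetectionHead (parameter order taken from the diff)."""

    def __init__(self, num_classes=80, anchor_sizes_per_layer=(), strides=(),
                 use_implicit_modules=True, in_channels_per_layer=()):
        self.use_implicit_modules = use_implicit_modules
        self.in_channels_per_layer = tuple(in_channels_per_layer)


class FixedToyHeadWithAux(ToyHead):
    """Accepts use_implicit_modules like the parent and forwards everything by keyword."""

    def __init__(self, num_classes=80, anchor_sizes_per_layer=(), strides=(),
                 use_implicit_modules=True, in_channels_per_layer=()):
        # Keyword forwarding: each value reaches the parameter with the same name.
        super().__init__(
            num_classes=num_classes,
            anchor_sizes_per_layer=anchor_sizes_per_layer,
            strides=strides,
            use_implicit_modules=use_implicit_modules,
            in_channels_per_layer=in_channels_per_layer,
        )


# Works whether or not the config supplies use_implicit_modules.
explicit = FixedToyHeadWithAux(80, (), (), True, in_channels_per_layer=(256, 512))
defaulted = FixedToyHeadWithAux(80, (), (), in_channels_per_layer=(256, 512))
print(explicit.use_implicit_modules, explicit.in_channels_per_layer)    # True (256, 512)
print(defaulted.use_implicit_modules, defaulted.in_channels_per_layer)  # True (256, 512)
```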