ultralytics / yolov5

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite
https://docs.ultralytics.com
GNU Affero General Public License v3.0

Apply Transformer in the backbone #2329

Closed dingyiwei closed 3 years ago

dingyiwei commented 3 years ago

🚀 Feature

Transformers are popular in NLP and are now also applied in CV. I added C3TR simply by replacing the sequential self.m in C3 with a Transformer block, which can reduce GFLOPs and help YOLO achieve better results.

Motivation

Pitch

I added 3 classes in https://github.com/dingyiwei/yolov5/blob/Transformer/models/common.py :

class TransformerLayer(nn.Module):
    def __init__(self, c, num_heads):
        super().__init__()

        self.ln1 = nn.LayerNorm(c)
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.ln2 = nn.LayerNorm(c)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x
        x = self.ln2(x)
        x = self.fc2(self.fc1(x)) + x
        return x

class TransformerBlock(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()

        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)
        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
        self.c2 = c2

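    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2)        # (b, c, w*h)
        p = p.unsqueeze(0)      # (1, b, c, w*h)
        p = p.transpose(0, 3)   # (w*h, b, c, 1)
        p = p.squeeze(3)        # (w*h, b, c): sequence-first layout used by nn.MultiheadAttention
        e = self.linear(p)      # learned linear embedding of each token, added back to the tokens
        x = p + e

        x = self.tr(x)          # (w*h, b, c)
        x = x.unsqueeze(3)      # (w*h, b, c, 1)
        x = x.transpose(0, 3)   # (1, b, c, w*h)
        x = x.reshape(b, self.c2, w, h)
        return x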

class C3TR(C3):
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = TransformerBlock(c_, c_, 4, n)

I simply placed it at the end of the backbone, in place of the last C3 block.

backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 9, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 3, C3TR, [1024, False]],  # 9    <---- here is my modification
  ]

I ran experiments on 2 Nvidia GTX 1080Ti cards, with depth_multiple and width_multiple set the same as Yolov5s. Here are my results at img-size 640. For convenience, I call the method in this issue Yolov5TRs.

| Model | Params | GFLOPs |
|---|---|---|
| Yolov5s | 7266973 | 17.0 |
| Yolov5TRs | 7268765 | 16.8 |

| Model | Dataset | TTA | mAP@.5 | mAP@.5:.95 | Speed (ms) |
|---|---|---|---|---|---|
| Yolov5s | coco (val) | N | 0.558 | 0.365 | 4.4 |
| Yolov5TRs | coco (val) | N | 0.568 | 0.363 | 4.4 |
| Yolov5s | coco (test-dev) | N | 0.559 | 0.365 | 4.6 |
| Yolov5TRs | coco (test-dev) | N | 0.567 | 0.365 | 4.5 |
| Yolov5s | coco (test-dev) | Y | 0.568 | 0.378 | 12.0 |
| Yolov5TRs | coco (test-dev) | Y | 0.571 | 0.375 | 11.0 |

We can see that Yolov5TRs scores higher in mAP@0.5 at the same or slightly faster speed. (I'm not sure why my Yolov5s results differ from those shown in the README; the model was downloaded from release v4.0.) When depth_multiple and width_multiple are set to larger values, C3TR should be more lightweight than C3. Since I don't have much time for this and my machine is not very powerful, I did not run experiments on M, L and X. Maybe someone could run those experiments :smile:
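A rough parameter count makes the "more lightweight" claim concrete. This sketch compares C3's `self.m` (n bottlenecks, each a 1x1 conv then a 3x3 conv) with the TransformerBlock above (c1 == c2, so no extra Conv). The `conv_bn`, `bottleneck_stack`, `transformer_layer` and `transformer_block` helpers are hypothetical stand-ins written from my reading of `models/common.py` (Conv = Conv2d without bias + BatchNorm2d + SiLU), not imports from the repo, and they are only meant for counting parameters.

```python
import torch.nn as nn


def conv_bn(c1, c2, k):
    # Hypothetical stand-in for YOLOv5's Conv: Conv2d (no bias) + BatchNorm2d + SiLU
    return nn.Sequential(nn.Conv2d(c1, c2, k, 1, k // 2, bias=False), nn.BatchNorm2d(c2), nn.SiLU())


def bottleneck_stack(c, n):
    # C3's self.m: n bottlenecks, each a 1x1 conv followed by a 3x3 conv at hidden width c
    return nn.Sequential(*[nn.Sequential(conv_bn(c, c, 1), conv_bn(c, c, 3)) for _ in range(n)])


def transformer_layer(c, num_heads):
    # Parameter-bearing modules of the TransformerLayer posted above (with both LayerNorms)
    return nn.ModuleList([
        nn.LayerNorm(c),
        nn.Linear(c, c, bias=False), nn.Linear(c, c, bias=False), nn.Linear(c, c, bias=False),
        nn.MultiheadAttention(embed_dim=c, num_heads=num_heads),
        nn.LayerNorm(c),
        nn.Linear(c, c, bias=False), nn.Linear(c, c, bias=False),
    ])


def transformer_block(c, n, num_heads=4):
    # TransformerBlock with c1 == c2: one embedding Linear (with bias) plus n layers
    return nn.ModuleList([nn.Linear(c, c)] + [transformer_layer(c, num_heads) for _ in range(n)])


def count(m):
    return sum(p.numel() for p in m.parameters())


for c, n in [(256, 1), (256, 3), (512, 3)]:
    print(f"c={c} n={n}: C3.m={count(bottleneck_stack(c, n)):,}  C3TR.m={count(transformer_block(c, n)):,}")
```

With n=1 (YOLOv5s layer 9: hidden width 256, one repeat) the TransformerBlock comes out 1,280 parameters larger, which matches the gap between the two YOLOv5s model summaries glenn-jocher reports in this thread (7,276,605 vs 7,277,885). Each extra repeat saves roughly c² parameters (9c² + 8c per TransformerLayer vs 10c² + 4c per bottleneck), so C3TR gets lighter as depth_multiple grows.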

github-actions[bot] commented 3 years ago

👋 Hello @dingyiwei, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt


glenn-jocher commented 3 years ago

@dingyiwei hey very cool!! The updates seem a bit faster with a bit less FLOPS... I'll have to look at this a little more in depth, but very quickly I would add that the C3TR module you placed at the end of the backbone will primarily affect large objects, so many of the smaller objects may not be significantly affected by the change.

To give a bit of background: the largest C3 modules, like the 1024-channel one you replaced, account for most of the model's parameter count but execute very fast (due to the small 20x20 feature grid they sample), whereas the earliest C3 modules, like those at P2/4 and P3/8, have very few parameters but are slow to execute due to their small stride and large grid, i.e. 160x160 and 80x80.

So it would be interesting to see the effects of replacing the 256- and 512-channel C3 modules as well.

glenn-jocher commented 3 years ago

@dingyiwei just checked, we have a multigpu instance freeing up soon, I think we can add a few C3TR runs to the queue to experiment further. Could you submit a PR with your above updates please?

glenn-jocher commented 3 years ago

@dingyiwei I pasted your modules into common.py and added C3TR to the modules list in yolo.py, and I can build a model successfully, but my numbers look a little different than yours:

default YOLOv5s

Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPS

[-1, 3, C3TR, [1024, False]], # 9

Model Summary: 276 layers, 6686013 parameters, 6686013 gradients, 16.6 GFLOPS

My full C3TR module (with only self.m different):

class C3TR(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3TR, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = TransformerBlock(c_, c_, 4, n)

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

EDIT: had to add C3TR in a second spot in yolo.py, now I match your numbers.

Model Summary: 286 layers, 7277885 parameters, 7277885 gradients, 16.8 GFLOPS
Joker316701882 commented 3 years ago

@dingyiwei @glenn-jocher Applying dropout can greatly improve a Transformer's performance, so I made a slight modification to TransformerLayer and observed improvements on my own model on COCO val (based on YOLOv5l). I'm not very familiar with the standard Transformer, but most Transformer code I've seen applies Dropout, so the following TransformerLayer could be a better implementation.

class TransformerLayer(nn.Module):
    def __init__(self, c, num_heads):
        super().__init__()

        self.ln1 = nn.LayerNorm(c)
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.ln2 = nn.LayerNorm(c)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)
        self.dropout = nn.Dropout(0.1)
        self.act = nn.ReLU(True)

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.dropout(self.ma(self.q(x_), self.k(x_), self.v(x_))[0]) + x
        x_ = self.ln2(x)
        x_ = self.fc2(self.dropout(self.act(self.fc1(x_))))
        x = x + self.dropout(x_)
        return x
dingyiwei commented 3 years ago

Hi @Joker316701882, I actually removed dropout at the beginning since there's no dropout anywhere in this codebase 🤣. I'll try it on VOC now. FYI, I previously tried nn.SiLU in self.fc2(self.act(self.fc1(x))) but got a worse result, so you could also run experiments without activation functions in TransformerLayer.

NanoCode012 commented 3 years ago

Hello @dingyiwei , may I ask if you trained with multi-gpu option or single-gpu? I saw that you wrote "2 Nvidia GTX 1080Ti cards" in your first post.

The reason I'm asking is that I launched 2-GPU and 4-GPU runs of 5m/5l using your backbone and got an error around the 110th-120th epoch.

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

Do you have any clue about this error? I also recall that glenn was planning to do multi-GPU training on this branch as well. Could you tell me if you ran into any errors too?

dingyiwei commented 3 years ago

Hi @NanoCode012, I ran my experiments with python train.py --data coco.yaml --cfg yolotrs.yaml --weights '' --batch-size 64. I saw that both of my GPUs were in use, so I just left them running. I've never hit this problem because I didn't use DDP mode.

Judging by the error message, I guess the problem could be caused by nn.MultiheadAttention. Its forward returns 2 outputs, attn_output and attn_output_weights, and only the first one is what we need:

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x   # <---- here we only use the first output
        x = self.ln2(x)
        x = self.fc2(self.fc1(x)) + x
        return x

I'm going to check it when my last experiment finishes.
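For reference, a minimal check of what nn.MultiheadAttention returns in its default sequence-first layout; the sizes are arbitrary (400 tokens corresponds to a 20x20 grid). Note that the discarded weights are computed from the same q/k projections that also produce attn_output, so dropping the second output does not by itself leave any parameter out of the loss graph; this sketch only illustrates the return value, not the root cause of the DDP error.

```python
import torch
import torch.nn as nn

c, num_heads = 64, 4
ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)  # sequence-first by default

x = torch.rand(400, 2, c)  # (seq_len = 20*20 tokens, batch = 2, channels)
attn_output, attn_weights = ma(x, x, x)

print(attn_output.shape)   # torch.Size([400, 2, 64]): what TransformerLayer keeps via [0]
print(attn_weights.shape)  # torch.Size([2, 400, 400]): averaged over heads by default, discarded
```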

NanoCode012 commented 3 years ago

Hello @dingyiwei, I see!

Have you tried training on a single GPU instead? In my tests on COCO, DP didn't actually speed up training. Maybe you could run two trainings in parallel instead of one :)

I found an issue, https://github.com/pytorch/pytorch/issues/26698, that discusses the incompatibility of nn.MultiheadAttention with DDP. I will try the solution proposed there, quoted below. The author did mention that it introduced another bug, but I'll have to test it. I guess we will need a PR for DDP support if we decide to include the Transformer in the backbone.

passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel` 

Another note: this can introduce some overhead in DDP https://pytorch.org/docs/stable/notes/ddp.html


Forward Pass: The DDP takes the input and passes it to the local model, and then analyzes the output from the local model if find_unused_parameters is set to True.... Note that traversing the autograd graph introduces extra overheads, so applications should only set find_unused_parameters to True when necessary.
dingyiwei commented 3 years ago

Hi @Joker316701882, I tested dropout and dropout+act on VOC (based on yolov5s + Transformer), but saw no obvious improvement. Could you share your experimental results with dropout?

@glenn-jocher @zhiqwang @NanoCode012 I also found a MISTAKE in my PR #2333: in a classic Transformer layer, the 2nd LayerNorm should be placed inside the 2nd residual block (as in Joker316701882's comment), according to ViT. But I applied x = self.ln2(x) on its own, outside the residual branch...


Fortunately, so far I haven't seen any harm or benefit from the mistake, but I'm not sure how it will affect larger models.
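For comparison, a sketch of the pre-norm arrangement ViT uses, where each LayerNorm sits inside its residual branch (x + f(LN(x))) and the skip path is never normalized. It keeps the bias-free q/k/v and fc1/fc2 projections from the code above and, as an assumption, omits the activation (ViT's MLP uses GELU) and the dropout Joker316701882 suggested. The class name is hypothetical and this is an illustration, not the code merged in #2333.

```python
import torch
import torch.nn as nn


class PreNormTransformerLayer(nn.Module):
    # Hypothetical name; sequence-first input of shape (seq_len, batch, c)
    def __init__(self, c, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(c)
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.ln2 = nn.LayerNorm(c)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x_ = self.ln1(x)  # 1st residual branch: x + MHA(LN1(x))
        x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x
        x_ = self.ln2(x)  # 2nd residual branch: x + MLP(LN2(x)); skip path stays un-normalized
        x = self.fc2(self.fc1(x_)) + x
        return x


layer = PreNormTransformerLayer(c=64, num_heads=4)
y = layer(torch.rand(400, 2, 64))
print(y.shape)  # torch.Size([400, 2, 64])
```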

jaqub-manuel commented 3 years ago

Hey @dingyiwei, I applied your addition to a custom dataset and saw a slight increase of 0.005. Why did you apply it after SPP (1024-channel)? Could you explain a little more for YOLOv5? I applied it before SPP (512-channel) but got lower results. Thanks...

dingyiwei commented 3 years ago

Hi @jaqub-manuel, components with a self-attention mechanism, e.g. Non-local and GCNet, are usually used to extract global information, so I intuitively put the Transformer at the end of the backbone.

@glenn-jocher is trying to put Transformer in different stages of the backbone and in the head of Yolov5. Maybe his experiments could give us some ideas.

glenn-jocher commented 3 years ago

@dingyiwei @jaqub-manuel I started an experiment run but got sidetracked earlier in the week. I discovered some important information though. It seems like the transformer block uses up a lot of memory. I created a transformer branch: https://github.com/ultralytics/yolov5/tree/transformer

And tried to train 8 models, 1 default yolov5m.yaml and then 7 transformer models. Each of the transformer models replaces C3 with C3TR in the location mentioned, i.e. only in layer 2, or only in backbone, etc. https://github.com/ultralytics/yolov5/tree/transformer/models


Unfortunately all of the 7 models except the layer 9 model hit CUDA OOM, so I cancelled the training to think a bit. The layers that use the least amount of CUDA memory are the largest stride layers (P5/32), like layer 9, so this may be why @dingyiwei was using it for the test. I think maybe layer 9 is then the best place to implement, as it uses less memory, and affects the whole head. So all I've really learned is that the default test @dingyiwei ran is probably the best for producing a trainable model that doesn't eat too many resources.
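A back-of-envelope sketch of why the high-resolution layers run out of memory: self-attention over an N-token grid builds an N x N weight matrix per head, so halving the stride quadruples N and multiplies that matrix by 16. It assumes a 640 input, batch 16, the 4 heads used in C3TR and fp32, and counts only one layer's attention weights (no activations saved for backward, no other tensors), so real usage is higher. The function name is hypothetical.

```python
def attn_weight_bytes(img_size, stride, batch, num_heads, bytes_per_el=4):
    """Rough size of one layer's attention-weight tensor (batch, heads, N, N), N = tokens."""
    n = (img_size // stride) ** 2  # one token per grid cell
    return batch * num_heads * n * n * bytes_per_el


for name, stride in [("P2/4", 4), ("P3/8", 8), ("P4/16", 16), ("P5/32", 32)]:
    gib = attn_weight_bytes(640, stride, batch=16, num_heads=4) / 2**30
    print(f"{name}: {gib:.3f} GiB of attention weights per TransformerLayer")
```

Under these assumptions P5/32 needs well under 0.1 GiB per layer while P3/8 already needs close to 10 GiB, which is consistent with only the layer 9 model fitting in memory.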

@dingyiwei can you update the PR with a fix for the mistake in https://github.com/ultralytics/yolov5/issues/2329#issuecomment-790372849, and then I'll train a YOLOv5m model side by side with the layer 9 replacement, and maybe I can try a layer 9 + P5 head replacement also. The P5 layer itself is the largest mAP contributor at 640 resolution, so it's not all bad news that we can only apply the transformer to that layer to minimize memory usage.

NanoCode012 commented 3 years ago

Hello, I finished most of my trainings (2 left) on testing the Transformer. I noted down my results in wandb. It's my first time using it, so I hope I'm doing it right.

Transformer runs on wandb

My observation was that the Transformer runs (denoted by tr) produced mixed results. They weren't as clear-cut as in @dingyiwei's first post. Also, the experiment with the 2nd LayerNorm fix (4_5trmv2) got lower results than the one without the fix (4_5trmv1).


Edit: Added table here for backup

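| Name | batch_size | test mAP@0.5 | test mAP@0.5:0.95 | pycocotools mAP@0.5 | pycocotools mAP@0.5:0.95 |
|---|---|---|---|---|---|
| 1_5m | 64 | 62 | 42.3 | 62.7 | 43.6 |
| 1_5trm | 64 | 61.9 | 42.2 | 62.2 | 43.4 |
| 4_5mv2 | 256 | 62.5 | 42.6 | 63.3 | 43.9 |
| 4_5trmv1 | 256 | 62.2 | 42.6 | 62.9 | 43.8 |
| 4_5trmv2 | 256 | 62.1 | 42.2 | 62.8 | 43.4 |
| 1_5lv2 | 48 | 64 | 45 | 64.7 | 46.2 |
| 1_5trl | 48 | 65.4 | 45.7 | 66 | 46.9 |
| 4_5trlv3 | 128 | 65.3 | 45.8 | 66 | 47 |
| 1_5trx | 32 | - | - | - | - |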
dingyiwei commented 3 years ago

Inspired by @NanoCode012, I tried removing both LayerNorm layers from the Transformer in YOLOv5s, and got a surprise:

| Model | Dataset | TTA | mAP@.5 | mAP@.5:.95 |
|---|---|---|---|---|
| Yolov5s | coco (val) | N | 0.558 | 0.365 |
| Yolov5s + Tr | coco (val) | N | 0.568 | 0.363 |
| Yolov5s + Tr(without LN) | coco (val) | N | 0.571 | 0.366 |

Will run on test-dev and upload the model later.

UPDATE:

Experimental results:

| Model | Dataset | TTA | mAP@.5 | mAP@.5:.95 |
|---|---|---|---|---|
| Yolov5s | coco (val) | N | 0.558 | 0.365 |
| Yolov5s + Tr | coco (val) | N | 0.568 | 0.363 |
| Yolov5s + Tr(without LN) | coco (val) | N | 0.571 | 0.366 |
| Yolov5s | coco (test-dev) | N | 0.559 | 0.365 |
| Yolov5s + Tr | coco (test-dev) | N | 0.567 | 0.365 |
| Yolov5s + Tr(without LN) | coco (test-dev) | N | 0.569 | 0.366 |
| Yolov5s | coco (test-dev) | Y | 0.568 | 0.378 |
| Yolov5s + Tr | coco (test-dev) | Y | 0.571 | 0.375 |
| Yolov5s + Tr(without LN) | coco (test-dev) | Y | 0.573 | 0.377 |

Here is the implementation:

class TransformerLayer(nn.Module):
    def __init__(self, c, num_heads):
        super().__init__()

        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x

New model is here.

Joker316701882 commented 3 years ago

@dingyiwei According to your posted results, mAP@.5 improved but mAP@.5:.95 remains unchanged. Does that mean mAP@.75 actually dropped?

dingyiwei commented 3 years ago

Hi @Joker316701882, I didn't record mAP@.75 in those experiments. According to @glenn-jocher's explanation, C3TR at the end of the backbone mainly affects large objects. So I guess mAP@.95 would drop and mAP@.75 might be unchanged.

jaqub-manuel commented 3 years ago

Dear @dingyiwei, could you upload the new model or share a link or the code? Then I will try it on my custom dataset.

dingyiwei commented 3 years ago

Hi everyone, I updated the experimental results, the implementation and the trained model of C3TR without LN in this comment the day before yesterday. It seems editing a comment would not trigger a notification or an email, so I just remind you about that.

glenn-jocher commented 3 years ago

@dingyiwei very interesting result! I think layernorm() is a pretty resource intensive operation (at least when compared to batchnorm). Did removing it reduce the training memory requirements?

dingyiwei commented 3 years ago

Hi @glenn-jocher , in my experiments yes. For YOLOv5s + TR, gpu_mem showed 6.63G, while for YOLOv5s + TR(without LN), gpu_mem showed 6.61G.

glenn-jocher commented 3 years ago

@dingyiwei thanks for the info, so not much of a change in memory from removing layernorm().

zachluo commented 3 years ago

Hi all, did anyone try a position embedding? Judging by the AP@0.5 and AP@0.5:0.95 results, the transformer seems to help classification rather than localization.
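On the question above: in TransformerBlock, e = self.linear(p) is computed from each token's own features, so it does not encode where the token sits on the grid. One common alternative is a learned absolute position embedding with one vector per grid cell. A minimal sketch, assuming a fixed 20x20 P5 grid (640 input, stride 32); the class name is hypothetical, and a fixed-size table like this would need resizing or interpolation to handle other input sizes.

```python
import torch
import torch.nn as nn


class LearnedPosEmbed(nn.Module):
    # Hypothetical module: one learnable c-dim vector per grid cell, sequence-first layout
    def __init__(self, c, grid_h, grid_w):
        super().__init__()
        self.pos = nn.Parameter(torch.zeros(grid_h * grid_w, 1, c))
        nn.init.trunc_normal_(self.pos, std=0.02)

    def forward(self, p):
        # p: (grid_h * grid_w, batch, c); the (N, 1, c) table broadcasts over the batch
        return p + self.pos


x = torch.rand(2, 256, 20, 20)     # (b, c, h, w) feature map
p = x.flatten(2).permute(2, 0, 1)  # (h*w, b, c) tokens in row-major grid order
embed = LearnedPosEmbed(256, 20, 20)
out = embed(p)
print(out.shape)  # torch.Size([400, 2, 256])
```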

glenn-jocher commented 3 years ago

@dingyiwei I'm working on getting the Transformer PR https://github.com/ultralytics/yolov5/pull/2333 merged. I merged master to bring it up to date with the latest changes, and I noticed that the TransformerLayer() module in the PR is different from your most recent one in https://github.com/ultralytics/yolov5/issues/2329#issuecomment-800045863. Which do you think we should use for the PR? Let me know, thanks!

glenn-jocher commented 3 years ago

@dingyiwei also we should add a one-line comment for each of the 3 new modules that explains a bit or cites a source if you can please. I've done this with C3TR(), but left the other two up to you.

Once we have these updates and decide on TransformerLayer() then I can merge the PR. Thanks!

Alex-afka commented 3 years ago


How do I train this new module? Can you show me the details? Did you train from pretrained weights or from scratch?

glenn-jocher commented 3 years ago

@Alex-afka PR https://github.com/ultralytics/yolov5/pull/2333 provides a transformer version of YOLOv5s in models/yolov5s-transformer.yaml. You train it the same way as any other model. You're free to provide pretrained weights or train from scratch. If you provide pretrained weights all layers with matching names and shapes will transfer over.

dingyiwei commented 3 years ago

Thanks @glenn-jocher, I prefer the newest version (Transformer without LayerNorm) since it does perform better on COCO. I will update the PR later, including the code and the comments.

Alex-afka commented 3 years ago

I find that adding a transformer gives better results on small objects in my own data. What about your experiments on COCO?


glenn-jocher commented 3 years ago

@dingyiwei great thanks! Once you update the PR I will review.

NanoCode012 commented 3 years ago

For multi-GPU runs, we need to pass find_unused_parameters=True to DDP. Should this go in code comments, or in the Multi-GPU guide?

dingyiwei commented 3 years ago

Hi @NanoCode012, in my latest update to #2333 I check whether an nn.MultiheadAttention (MA) layer is in the model before wrapping it in DDP. Could you review it as well?
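The check described here can be sketched as follows: scan the model for nn.MultiheadAttention modules and only then enable find_unused_parameters, since the PyTorch DDP notes quoted earlier say to set it only when necessary. The helper names are hypothetical and this is not necessarily the exact code in #2333; wrap_ddp assumes torch.distributed.init_process_group(...) has already been called.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def has_multihead_attention(model: nn.Module) -> bool:
    # True if any submodule is an nn.MultiheadAttention layer (e.g. inside C3TR)
    return any(isinstance(m, nn.MultiheadAttention) for m in model.modules())


def wrap_ddp(model: nn.Module, local_rank: int) -> DDP:
    # Requires an initialized process group; pays the unused-parameter search cost only when needed
    return DDP(model, device_ids=[local_rank], output_device=local_rank,
               find_unused_parameters=has_multihead_attention(model))


plain = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
with_attn = nn.Sequential(nn.Linear(16, 16), nn.MultiheadAttention(16, 4))
print(has_multihead_attention(plain), has_multihead_attention(with_attn))  # False True
```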

NanoCode012 commented 3 years ago

@dingyiwei, That's a nice and simple solution!

glenn-jocher commented 3 years ago

@all PR #2333 is merged now, providing official Transformer module support in TransformerLayer, TransformerBlock (using multiple TransformerLayers), and C3TR (C3 module with Conv2d() sequence replaced by TransformerBlock).

Thanks to @dingyiwei for the excellent work!

EDIT: reopening issue to continue convo, was auto-closed on PR merge.

Alex-afka commented 3 years ago

Hi, I want to add a Transformer to YOLOv4. Can you give me some advice? I want to add it after the SPP block. Is that a good idea?

ShirleyHe2020 commented 3 years ago

Motivation

  • Dosovitskiy et al. proposed ViT
  • Facebook applied Transformer on object detection as an encoder
  • So I thought Transformer could make yolo better

Hi Dingyiwei, thanks for your work. I have 2 questions: 1. Does C3TR help resolve class-imbalance problems? 2. How does C3TR work? (Any suggestions to help me understand this module?)

dingyiwei commented 3 years ago

Hi @ShirleyHe2020 ,

  1. So far there is no evidence of a connection between C3TR/Transformer (the self-attention mechanism) and the class-imbalance problem.
  2. The biggest difference between C3 and C3TR is Transformer. Maybe the papers I listed in my motivation are helpful to you.
Alex-afka commented 3 years ago

Hi, I have a question. I want to use a Transformer in the PAFPN rather than in the backbone. Can you give me some advice? I don't know where in the PAFPN it would work best.

github-actions[bot] commented 3 years ago

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.

Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcomed!

Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!

guyiyifeurach commented 3 years ago

Hi, I downloaded the model above, but when I train with it, I get an error. Why does this happen? Can you tell me? @dingyiwei @glenn-jocher Thank you!

glenn-jocher commented 3 years ago

@guyiyifeurach there are no transformer pretrained weights, but you can start from the normal pretrained weights instead. To train a YOLOv5s transformer model in our Colab notebook for example: https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb

# Train YOLOv5s on COCO128 for 3 epochs
!python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --cfg yolov5s-transformer.yaml
qiy20 commented 2 years ago

Does this dimension manipulation change the batch_size dimension? I don't understand why we're doing this.

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c 
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)

I think the right operation is:

p = x.flatten(2).transpose(1, 2)

@dingyiwei

dingyiwei commented 2 years ago

Hi @qiy20 , I forgot why to write this piece of code😂. Feel free to update it if you confirm it is correct.

glenn-jocher commented 2 years ago

@qiy20 @dingyiwei would the right simplification be this?

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c 
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)

# simplified
p = x.flatten(2).transpose(0, 2)
dingyiwei commented 2 years ago

@glenn-jocher I don't think so..

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c 
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)

# b,c,w,h-->b,c,wh-->b,wh,c
p = x.flatten(2).transpose(1, 2)

# b,c,w,h-->b,c,wh-->wh,c,b
p = x.flatten(2).transpose(0, 2)

I think my original idea was to keep c after b, and a single transpose cannot do that.

An alternative is adding batch_first=True to MultiheadAttention; then we could write:

p = x.flatten(2).transpose(1, 2)
return self.tr(p + self.linear(p)).transpose(1, 2).reshape(b, self.c2, w, h)

I'll verify it with experiments. Let me know if you get different ideas :)
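One way to sanity-check the batch_first=True idea before training: copy the weights of a default (sequence-first) nn.MultiheadAttention into a batch_first=True one and compare outputs on the same tokens. This needs PyTorch 1.9 or later, where batch_first was added; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c, heads = 64, 4
seq_first = nn.MultiheadAttention(c, heads)                      # expects (seq, batch, c)
batch_first = nn.MultiheadAttention(c, heads, batch_first=True)  # expects (batch, seq, c)
batch_first.load_state_dict(seq_first.state_dict())              # identical weights

x = torch.rand(2, c, 20, 20)          # (b, c, h, w) feature map
p_bf = x.flatten(2).transpose(1, 2)   # (b, h*w, c)
p_sf = x.flatten(2).permute(2, 0, 1)  # (h*w, b, c)

out_bf = batch_first(p_bf, p_bf, p_bf)[0]                 # (b, h*w, c)
out_sf = seq_first(p_sf, p_sf, p_sf)[0].permute(1, 0, 2)  # back to (b, h*w, c)
print(torch.allclose(out_bf, out_sf, atol=1e-5))  # True
```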

glenn-jocher commented 2 years ago

@dingyiwei ok I think I've got it. Yes, you are right: a single transpose doesn't produce the needed dimension order. I had to use permute, which also seems to make this reshaping step about 2x faster:

import torch

x= torch.rand(16,3,80,40)
p1 = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
p2 = x.flatten(2).permute(2,0,1)
print(torch.allclose(p1,p2))  # True

%timeit x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
# 5.36 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit x.flatten(2).permute(2,0,1)
# 2.83 µs ± 62 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
glenn-jocher commented 2 years ago

@dingyiwei if batch_first=True profiles faster that might be the better solution.

dingyiwei commented 2 years ago

@glenn-jocher Training time and inference time appear no difference among the current code, permute and batch_first=True.

I ran 10 epochs for each solution with python train.py --data data/coco.yaml --cfg models/hub/yolov5s-transformer.yaml --weights '' --batch-size 32 --epochs 10 and tested them with python val.py --data data/coco.yaml --weights runs/train/exp/weights/best.pt --img 640 on one 2080ti.

| Model | Training time (hours) | Inference time (ms) |
|---|---|---|
| Current | 5.053 | 2.6 |
| Permute | 5.053 | 2.6 |
| Batch first | 5.053 | 2.6 |

But permute is more elegant and readable, so I'll submit a pull request for it.

p = x.flatten(2).permute(2, 0, 1)
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
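To check that the permute() rewrite of TransformerBlock.forward is a pure refactor, both versions can be run on the same input with the same modules. In this sketch the block's conv is skipped (c1 == c2), and `tr` is a hypothetical stand-in nn.Sequential of nn.Linear layers, since only the reshaping is being compared; the function names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c = 32
linear = nn.Linear(c, c)                              # the block's learned embedding
tr = nn.Sequential(nn.Linear(c, c), nn.Linear(c, c))  # stand-in for the TransformerLayer stack


def forward_original(x):
    b, _, w, h = x.shape
    p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)  # (w*h, b, c)
    x = tr(p + linear(p))
    return x.unsqueeze(3).transpose(0, 3).reshape(b, c, w, h)


def forward_permute(x):
    b, _, w, h = x.shape
    p = x.flatten(2).permute(2, 0, 1)  # (w*h, b, c)
    return tr(p + linear(p)).permute(1, 2, 0).reshape(b, c, w, h)


x = torch.rand(4, c, 20, 16)  # non-square grid to catch mixed-up spatial dims
print(torch.allclose(forward_original(x), forward_permute(x)))  # True
```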
glenn-jocher commented 2 years ago

@dingyiwei understood! Yes please submit a PR for permute().