lyuwenyu / RT-DETR

[CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥
Apache License 2.0

Using BatchNorm1D in TransformerDecoderLayer raises CUDA error 700 #432

Open sdreamforchen opened 2 months ago

sdreamforchen commented 2 months ago

    import paddle.nn as nn
    from paddle import ParamAttr
    from paddle.nn.initializer import Constant

    class RepBN(nn.Layer):
        def __init__(self, channels):
            super(RepBN, self).__init__()
            # Learnable scale for the residual branch, initialized to 1.
            self.alpha = self.create_parameter(
                shape=[1],
                attr=ParamAttr(initializer=Constant(value=1.)),
                dtype="float32")
            self.bn = nn.BatchNorm1D(channels)

        def forward(self, x):
            # [B, L, C] -> [B, C, L] so BatchNorm1D normalizes over the channel axis.
            x = x.transpose([0, 2, 1])
            x = self.bn(x) + self.alpha * x
            x = x.transpose([0, 2, 1])
            return x
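
For reference, a minimal usage sketch of the RepBN defined above; the batch size, query count, and channel width are illustrative values, not taken from the issue:

    import paddle

    # Illustrative shapes only: 8 samples, 300 decoder queries, 256 channels.
    repbn = RepBN(channels=256)
    x = paddle.randn([8, 300, 256])   # [batch, sequence, channels]
    y = repbn(x)                      # BatchNorm1D runs over the channel axis
    print(y.shape)                    # [8, 300, 256]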

I used this to replace the LayerNorm layers inside TransformerDecoderLayer. Replacing the first two caused no problem: I checked GPU memory and it stayed normal, with no noticeable increase. But after also replacing the LayerNorm that follows the MLP, the error is raised even with batch_size=8 (it was originally 32, using 16 GB of the 24 GB available).
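
To make the failing configuration concrete, the substitution described above might look like the following sketch; the attribute names norm1/norm2/norm3 and the hidden width of 256 are assumptions for illustration and may not match the actual TransformerDecoderLayer implementation in this repository:

    # Hypothetical sketch of the substitution; norm1/norm2/norm3 and hidden_dim
    # are assumed names, not necessarily those used in RT-DETR's decoder layer.
    hidden_dim = 256

    def replace_norms_with_repbn(decoder_layer, hidden_dim=hidden_dim):
        decoder_layer.norm1 = RepBN(hidden_dim)  # after self-attention: works
        decoder_layer.norm2 = RepBN(hidden_dim)  # after cross-attention: works
        decoder_layer.norm3 = RepBN(hidden_dim)  # after the MLP: triggers CUDA error 700
        return decoder_layer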