def decode(self, input, feature, src_mask, tgt_mask):
    """Run the shared decoder layers, then the cls and bbox heads.

    ``input`` is embedded and position-encoded, then passed through
    ``self.layers`` in sequence. The resulting trunk output feeds two
    separate heads (``self.cls_layer`` and ``self.bbox_layer``); each head
    output is normalised with the shared ``self.norm`` and projected by its
    own final layer (``self.cls_fc`` / ``self.bbox_fc``).

    Args:
        input: Value handed to ``self.embedding``. The name shadows the
            builtin but is kept so keyword-argument callers keep working.
        feature: Forwarded unchanged to every decoder layer (presumably the
            encoder output -- confirm against the caller).
        src_mask: Forwarded unchanged to every decoder layer.
        tgt_mask: Forwarded unchanged to every decoder layer.

    Returns:
        The tuple ``(self.cls_fc(cls_x), self.bbox_fc(bbox_x))``.
    """
    x = self.embedding(input)
    x = self.positional_encoding(x)

    # Shared trunk: layers are chained, each one consuming the previous
    # layer's output.
    for layer in self.layers:
        x = layer(x, feature, src_mask, tgt_mask)

    # Classification head. Every head layer is fed the trunk output ``x``
    # (head layers are NOT chained), so only the last layer's result is
    # kept. Seeding with ``x`` keeps the head defined when it has no layers
    # instead of failing with an unbound local.
    cls_x = x
    for layer in self.cls_layer:
        cls_x = layer(x, feature, src_mask, tgt_mask)
    cls_x = self.norm(cls_x)

    # Bounding-box head: same structure as the classification head, and it
    # reuses the same ``self.norm`` module.
    bbox_x = x
    for layer in self.bbox_layer:
        bbox_x = layer(x, feature, src_mask, tgt_mask)
    bbox_x = self.norm(bbox_x)

    return self.cls_fc(cls_x), self.bbox_fc(bbox_x)