zhihou7 / BatchFormer

CVPR2022, BatchFormer: Learning to Explore Sample Relationships for Robust Representation Learning, https://arxiv.org/abs/2203.01522

The code and training for BatchFormerv2 #8

Closed: OBVIOUSDAWN closed this issue 2 years ago.

OBVIOUSDAWN commented 2 years ago

Thank you for your work on BatchFormer. I have read the BatchFormerV2 paper and found it inspiring. As a beginner, I am not sure whether the two-stream training model means two identical networks that interact through a transformer block, or just one network that is trained with and without BatchFormerV2 in turn at each step, with the module finally frozen at validation. If it is the latter, the backpropagation problem will be easy to solve. I would also like to know when the code for BatchFormerV2 will be released. Thank you very much; looking forward to your reply!

zhihou7 commented 2 years ago

Hi @OBVIOUSDAWN, thanks for your interest. It is the latter: just one network (two streams), trained with and without BatchFormerV2 in turn at each step. However, the module is not frozen at validation; instead, the BatchFormerV2 module is removed at validation.

The implementation is very easy.

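import torch

def batch_former_v2(x, encoder, is_training, is_first_layer):
  # x: input features with the shape (B, N, C).
  # encoder: TransformerEncoderLayer(C, nhead, C, 0.5, batch_first=False),
  # so self-attention runs over dim 0, i.e. across the samples in the batch.
  if not is_training:
    return x  # BatchFormerV2 is removed at validation/test time
  orig_x = x
  if not is_first_layer:
    # x already holds [original stream; BatchFormerV2 stream] from an earlier layer
    orig_x, x = torch.split(x, len(x) // 2)
  x = encoder(x)
  x = torch.cat([orig_x, x], dim=0)  # stack the two streams along the batch dimension

  return x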
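To see how the two streams behave when the function is used at several layers, here is a minimal, self-contained sketch. It repeats the function so it runs on its own and assumes a hypothetical toy model (`blocks`, `encoders`, `forward` are illustrative names, not from the BatchFormer repo) in which a shared `nn.Linear` stands in for each real backbone/transformer layer. It shows the batch doubling at the first insertion point and staying doubled afterwards, the module being skipped at evaluation, and the original stream never passing through an encoder.

```python
import torch
from torch import nn


def batch_former_v2(x, encoder, is_training, is_first_layer):
    # Same logic as the function above, repeated so this snippet runs on its own.
    if not is_training:
        return x  # BatchFormerV2 is removed at evaluation
    orig_x = x
    if not is_first_layer:
        # x already holds [original stream; BatchFormerV2 stream]
        orig_x, x = torch.split(x, len(x) // 2)
    x = encoder(x)  # attention over dim 0, i.e. across the samples
    return torch.cat([orig_x, x], dim=0)


# Hypothetical toy model: a shared per-token block followed by a BatchFormerV2 encoder at each stage.
B, N, C, nhead, num_stages = 4, 10, 32, 4, 3
blocks = nn.ModuleList([nn.Linear(C, C) for _ in range(num_stages)])
encoders = nn.ModuleList(
    [nn.TransformerEncoderLayer(C, nhead, C, 0.5, batch_first=False) for _ in range(num_stages)]
)


def forward(x, is_training):
    for i, (block, enc) in enumerate(zip(blocks, encoders)):
        x = block(x)  # both streams share these weights
        x = batch_former_v2(x, enc, is_training, is_first_layer=(i == 0))
    return x


x = torch.randn(B, N, C)
train_out = forward(x, is_training=True)
eval_out = forward(x, is_training=False)
print(train_out.shape)  # torch.Size([8, 10, 32]): two streams stacked on dim 0
print(eval_out.shape)  # torch.Size([4, 10, 32]): BatchFormerV2 skipped
# The original stream never passes through an encoder, so it matches the evaluation path.
print(torch.allclose(train_out[:B], eval_out, atol=1e-5))
```

Because the shared layers see both streams during training but the original stream's forward pass never depends on the encoders, dropping the module at validation leaves a network that was trained on exactly that path.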

You should also expand the labels along the batch dimension, as the following example shows:

y = torch.cat([y, y], dim=0)
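For the classification case only, a minimal sketch of why the duplicated labels line up: `torch.cat([orig_x, x], dim=0)` puts the original stream first and the BatchFormerV2 stream second, so `torch.cat([y, y], dim=0)` gives each copy its own sample's label. The classifier `head`, the random stand-in `features`, and the use of plain cross-entropy are assumptions for illustration, not the repo's training code.

```python
import torch
from torch import nn
import torch.nn.functional as F

B, C, num_classes = 4, 32, 5
head = nn.Linear(C, num_classes)  # hypothetical classifier shared by both streams

features = torch.randn(2 * B, C)  # stand-in for two-stream output: [original; BatchFormerV2]
y = torch.randint(0, num_classes, (B,))  # labels for the original mini-batch

y = torch.cat([y, y], dim=0)  # each sample's label is reused for its BatchFormerV2 copy
logits = head(features)
loss = F.cross_entropy(logits, y)  # averaged over both streams
loss.backward()  # the shared head receives gradients from both streams
print(y.shape, logits.shape)  # torch.Size([8]) torch.Size([8, 5])
```

For detection, the same duplication would apply to each target (boxes and categories) rather than to a single label tensor; that part is not sketched here.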

In BatchFormer, $y$ is the one-hot/multi-hot label, whereas in BatchFormerV2 $y$ might be bounding boxes and categories. You can implement it following the code in BatchFormer. For domain_generalization, you can see the modification with the following command:

diff baseline.py <(curl https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/master/examples/domain_generalization/image_classification/baseline.py)

Code will be released once the paper is accepted. If you need the code urgently, you can email me and I will send you the code based on Deformable-DETR.

OBVIOUSDAWN commented 2 years ago

I have emailed you, and I would appreciate your reply. Thanks for your help.