THU-MIG / RepViT

RepViT: Revisiting Mobile CNN From ViT Perspective [CVPR 2024] and RepViT-SAM: Towards Real-Time Segmenting Anything
https://arxiv.org/abs/2307.09283
Apache License 2.0
799 stars, 60 forks

Question about "Convert a training-time RepViT into the inference-time structure" with replace_batchnorm. #39

Topdu closed this issue 8 months ago.

Topdu commented 8 months ago

@jameslahm Hi,

Thanks for your great work! The model outputs before and after calling replace_batchnorm are inconsistent, which is puzzling. How can I get the correct results?

    model = create_model(model_name, num_classes=1000)
    y = model(inputs)
    print(y)
    utils.replace_batchnorm(model)
    y = model(inputs)
    print(y)

Output (first before, then after replace_batchnorm):

    tensor([[ 0.0716, -0.1728, -0.6425,  ..., -0.1067, -1.0730, -0.2081],
            [-0.0716,  0.1728,  0.6425,  ...,  0.1067,  1.0730,  0.2081]])
    tensor([[-0.0643, -0.4197,  0.1459,  ..., -0.0541, -0.1685,  0.1784],
            [-0.0643, -0.4197,  0.1459,  ..., -0.0541, -0.1685,  0.1784]])

jameslahm commented 8 months ago

Do all the modules in the model correctly implement the fuse function? Note that a BatchNorm layer inside a module without a fuse function is simply replaced with nn.Identity, so its normalization is dropped rather than folded into the preceding layer.
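To make the fuse/Identity distinction concrete, here is a minimal, self-contained sketch of the recursive walk described above. It is not the repo's utils.replace_batchnorm: ConvBN, randomize_bn and fused_matches are hypothetical names, and RepViT's real fuse methods may differ in detail. The sketch assumes only PyTorch and checks, in eval mode, whether the fused network reproduces the unfused output.

```python
import torch
import torch.nn as nn


class ConvBN(nn.Sequential):
    """Hypothetical Conv2d + BatchNorm2d pair that can fold itself into one Conv2d."""

    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        self.add_module("c", nn.Conv2d(in_ch, out_ch, k, padding=k // 2, bias=False))
        self.add_module("bn", nn.BatchNorm2d(out_ch))

    @torch.no_grad()
    def fuse(self):
        c, bn = self.c, self.bn
        # Folding reads the BN *running* statistics:
        # W' = W * gamma / sqrt(var + eps),  b' = beta - mean * gamma / sqrt(var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused = nn.Conv2d(c.in_channels, c.out_channels, c.kernel_size,
                          stride=c.stride, padding=c.padding,
                          dilation=c.dilation, groups=c.groups, bias=True)
        fused.weight.copy_(c.weight * scale[:, None, None, None])
        fused.bias.copy_(bn.bias - bn.running_mean * scale)
        return fused


def replace_batchnorm(net):
    """Sketch: fuse children that define fuse(); replace bare BatchNorm2d with Identity."""
    for name, child in net.named_children():
        if hasattr(child, "fuse"):
            fused = child.fuse()
            setattr(net, name, fused)
            replace_batchnorm(fused)
        elif isinstance(child, nn.BatchNorm2d):
            setattr(net, name, nn.Identity())  # normalization is silently dropped
        else:
            replace_batchnorm(child)


def randomize_bn(net):
    """Give every BN layer non-trivial affine parameters and running statistics."""
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean.uniform_(-1, 1)
                m.running_var.uniform_(0.5, 2.0)
                m.weight.uniform_(0.5, 1.5)
                m.bias.uniform_(-0.5, 0.5)


def fused_matches(net, x):
    """Compare eval-mode outputs before and after replace_batchnorm."""
    randomize_bn(net)
    net.eval()  # use running statistics and do not update them
    with torch.no_grad():
        y_ref = net(x)
        replace_batchnorm(net)
        y_fused = net(x)
    return torch.allclose(y_ref, y_fused, atol=1e-4)


torch.manual_seed(0)
x = torch.randn(2, 3, 16, 16)
ok = fused_matches(nn.Sequential(ConvBN(3, 8), nn.ReLU(), ConvBN(8, 4)), x)
bad = fused_matches(nn.Sequential(ConvBN(3, 8), nn.ReLU(), nn.BatchNorm2d(8)), x)
print(ok, bad)  # True False
```

The second comparison fails because its trailing BatchNorm2d has no fuse() and is replaced by nn.Identity instead of being folded, which is exactly the case the question above asks to rule out.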

Topdu commented 8 months ago

Thank you for the reply. The code I used is cloned from your repo without any changes.

Topdu commented 8 months ago

I only modified speed_gpu.py, as follows:

    from argparse import ArgumentParser

    import torch
    from timm.models import create_model

    import utils  # RepViT's utils.py, which provides replace_batchnorm
    # The repo's model definitions must also be imported so that
    # 'repvit_m0_9' is registered with timm's create_model.

    parser = ArgumentParser()

    parser.add_argument('--model', default='repvit_m0_9', type=str)
    parser.add_argument('--resolution', default=224, type=int)
    parser.add_argument('--batch-size', default=2, type=int)

    if __name__ == "__main__":
        args = parser.parse_args()
        model_name = args.model
        batch_size = args.batch_size
        resolution = args.resolution
        torch.cuda.empty_cache()
        inputs = torch.randn(batch_size, 3, resolution, resolution)
        model = create_model(model_name, num_classes=1000)
        y = model(inputs)  # forward pass before fusing
        print(y)
        utils.replace_batchnorm(model)
        y = model(inputs)  # forward pass after fusing
        print(y)
        # model.to(device)
        # model.eval()
        # throughput(model_name, model, device, batch_size, resolution=resolution)

jameslahm commented 8 months ago

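Thanks. It seems that the BatchNorm statistics changed during the first forward pass: the model is still in training mode at that point, so running_mean and running_var are updated before replace_batchnorm folds them. Could you please call model.eval() before the first y = model(inputs)?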
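The effect described above can be reproduced with a single layer. This standalone sketch uses a plain torch.nn.BatchNorm2d rather than RepViT itself (the tensor shapes and scaling are arbitrary choices for illustration). A freshly built module starts in training mode, so one forward pass overwrites running_mean and running_var, which are the values a BN-folding step reads; after eval() the same forward leaves them untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)                 # freshly built modules start in training mode
x = torch.randn(2, 4, 8, 8) * 3 + 1    # batch whose mean is far from the initial 0

print(bn.training)                     # True

before = bn.running_mean.clone()
with torch.no_grad():                  # no_grad does NOT stop the running-stat update
    bn(x)                              # training mode: normalize with batch stats and
                                       # update running_mean / running_var in place
changed = not torch.equal(before, bn.running_mean)
print(changed)                         # True

bn.eval()                              # inference mode: use running stats, never update them
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
unchanged = torch.equal(before, bn.running_mean)
print(unchanged)                       # True
```

The training flag, not the autograd mode, decides whether the statistics move, so wrapping the first forward pass in torch.no_grad() alone would not have avoided the mismatch.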

Topdu commented 8 months ago

It's working! Thanks!