microsoft / Cream

This is a collection of our NAS and Vision Transformer work.
MIT License
1.66k stars · 225 forks

Bug about weight sharing in AutoFormer #232

Open variant-star opened 6 months ago

variant-star commented 6 months ago

https://github.com/microsoft/Cream/blob/4a13c4091e78f9abd2160e7e01c02e48c1cf8fb9/AutoFormer/model/module/qkv_super.py#L72-L77

I think there is something wrong with the way weight sharing is done here. This code should instead be:

    # take sample_out_dim//3 rows from each of the Q, K, and V blocks
    N = weight.size(0) // 3
    sample_weight = torch.cat([sample_weight[i*N:i*N+sample_out_dim//3, :] for i in range(3)], dim=0)
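To illustrate the difference, here is a minimal standalone sketch (the dimensions are hypothetical, not taken from the repo): the super-net QKV projection stores Q, K, and V stacked along dim 0, so a sub-net should take its rows from each of the three blocks, not one contiguous slice from the top.

```python
import torch

# Hypothetical sizes for illustration only.
super_dim, in_dim = 12, 4      # super_dim = 3 * N, so N = 4 rows per Q/K/V block
sample_out_dim = 6             # sub-net wants 2 rows per Q/K/V block

weight = torch.arange(super_dim * in_dim, dtype=torch.float32).reshape(super_dim, in_dim)
N = weight.size(0) // 3        # rows per Q/K/V block in the super-net

# Proposed fix: take the first sample_out_dim//3 rows from EACH of Q, K, V.
per_block = sample_out_dim // 3
fixed = torch.cat([weight[i * N : i * N + per_block, :] for i in range(3)], dim=0)

# A single contiguous slice instead takes rows 0..5: all of the Q block
# plus part of K, and none of V.
contiguous = weight[:sample_out_dim, :]

print(fixed[:, 0])       # rows 0, 1 (Q), 4, 5 (K), 8, 9 (V)
print(contiguous[:, 0])  # rows 0..5 (Q and part of K only)
```

With these toy numbers, the fixed version draws two rows from each of Q, K, and V, while the contiguous slice never touches the V block at all, which is the mismatch the issue describes.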

To make this more intuitive, I drew a schematic diagram showing how 4-head and 5-head self-attention share Linear.weight.

[schematic diagram: Snipaste_2024-03-28_22-05-19]

Maybe I misunderstood the implementation here; could you help check it?