lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch

About torch.no_grad() #24

Open TianshengSun opened 2 years ago

TianshengSun commented 2 years ago

Hi, thanks for this implementation. I saw that the parameters of nn.Linear() are initialized under torch.no_grad() in models.py, line 139:

    @torch.no_grad()
    def init_weights(self):
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)  # _trunc_normal(m.weight, std=0.02)  # from .initialization import _trunc_normal
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)  # nn.init.constant(m.bias, 0)
        self.apply(_init)
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02)  # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02)
        nn.init.constant_(self.class_token, 0)

Does this mean this project only supports eval? Shouldn't these parameters be trainable if I want to train ViT on my own dataset?