LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701
MIT License

zero_module #24

Open fendi001 opened 7 months ago

fendi001 commented 7 months ago

@LTH14 I have not quite understood why zero_module is used in ResNet and SimpleMLP. It directly zeros out the tensors (input_rep_tensor + time_tensor + condition_tensor), so torch.all(h == 0) is True. Could this hurt or benefit the generation of the representation from the noised original representation in the diffusion process? What is the special purpose of zero_module in the following functions?

ResNet

```python
self.out_layers = nn.Sequential(
    nn.LayerNorm(mid_channels),
    nn.SiLU(),
    nn.Dropout(p=dropout),
    zero_module(
        nn.Linear(mid_channels, channels, bias=True)
    ),
)
```

SimpleMLP

```python
self.out = nn.Sequential(
    nn.LayerNorm(model_channels, eps=1e-6),
    nn.SiLU(),
    zero_module(nn.Linear(model_channels, out_channels, bias=True)),
)
```
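For context, here is a minimal sketch of what a zero_module helper typically looks like (the guided-diffusion-style version; the exact RCG code may differ in details):

```python
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    """Zero out the parameters of a module and return it (affects initialization only)."""
    for p in module.parameters():
        p.detach().zero_()
    return module

# Wrapping the final Linear means the output branch produces zeros at initialization,
# so the residual block initially behaves like an identity mapping.
out_proj = zero_module(nn.Linear(256, 256, bias=True))
```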

LTH14 commented 7 months ago

These zero modules just initialize the parameters of the wrapped layer to zero. It should not significantly affect performance.
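To illustrate the point, here is a small self-contained sketch (using a plain nn.Linear zero-initialized by hand, equivalent to wrapping it in zero_module): the output starts at zero, but gradients still flow, so the parameters become non-zero after the first optimizer step.

```python
import torch
import torch.nn as nn

# Zero-initialize a linear layer by hand (same effect as zero_module on this layer).
layer = nn.Linear(8, 8, bias=True)
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)

x = torch.randn(2, 8)
print(torch.all(layer(x) == 0))      # True: the branch's output starts at zero

# The parameters are only *initialized* to zero; gradients still flow through them,
# so a single optimization step already makes them non-zero.
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = (layer(x) - torch.ones(2, 8)).pow(2).mean()
loss.backward()
opt.step()
print(torch.all(layer.weight == 0))  # False: the layer trains normally
```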