ckczzj / PDAE

Official PyTorch implementation of PDAE (NeurIPS 2022)

Some questions about the implementation of unet_shift #3

Open xinyangATK opened 1 year ago

xinyangATK commented 1 year ago

Thank you so much for releasing your code. I have some questions that came up while reproducing your work. In the forward() function of class ResBlockShift(TimestepZBlock), the call out_rest(h) seems to set h to zero, which would make emb_z ineffective. Is there a problem in this module?

    # Adaptive Group Normalization
    out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
    # scale/shift are predicted from the timestep embedding,
    # z_scale/z_shift from the latent z embedding
    scale, shift = torch.chunk(emb_out, 2, dim=1)
    z_scale, z_shift = torch.chunk(emb_z_out, 2, dim=1)
    h = (1. + z_scale) * (out_norm(h) * (1. + scale) + shift) + z_shift
    h = out_rest(h)

    return self.skip_connection(x) + h
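
For context, here is how I assume out_layers is constructed, following the guided-diffusion convention (normalization, conv_nd, and zero_module are names from that codebase and may differ here):

    # hypothetical sketch of out_layers in the guided-diffusion style
    self.out_layers = nn.Sequential(
        normalization(self.out_channels),  # GroupNorm, used as out_norm above
        nn.SiLU(),
        nn.Dropout(p=dropout),
        zero_module(  # the final conv is wrapped in zero_module
            conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
        ),
    )

With this layout, out_rest = self.out_layers[1:] ends in that final conv, which is where the zero output I observed seems to come from.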
ckczzj commented 1 year ago

Thanks for your attention!

Do you mean the zero_module here? https://github.com/ckczzj/PDAE/blob/fbba0355634861196aed8b80b9ba4948ed210ab9/model/module/module.py#L362-L364

It is just a zero-initialization of the output conv layer. The zero-initialization makes the residual block act like an identity function at the beginning of training, which is a commonly used trick for stable training.
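
For reference, here is a minimal sketch of that helper, assuming it matches the usual guided-diffusion implementation:

    import torch.nn as nn

    def zero_module(module: nn.Module) -> nn.Module:
        # Zero out all parameters of a module in place and return it.
        # Applied to the final conv of a residual block, this makes the
        # block compute skip_connection(x) + 0 at the start of training.
        for p in module.parameters():
            p.detach().zero_()
        return module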

Although the parameters are initialized as zero, their gradients still exist. After the first update of the network, they will be almost entirely non-zero. The recent ControlNet work has had similar issues raised about its zero-initialized layers.
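
A quick standalone check (a hypothetical example, not code from this repo) shows that a zero-initialized conv produces zero output but still receives non-zero gradients:

    import torch
    import torch.nn as nn

    # a conv whose weight and bias start at exactly zero
    conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    for p in conv.parameters():
        p.detach().zero_()

    x = torch.randn(1, 3, 8, 8)
    out = conv(x)         # all zeros at initialization
    out.sum().backward()  # the backward pass still produces gradients

    print(out.abs().max())               # tensor(0.)
    print(conv.weight.grad.abs().max())  # > 0 for non-zero input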

xinyangATK commented 1 year ago

Thank you for your patient answer!

This really cleared up my confusion about this module.