PRBonn / lidar-bonnetal

Semantic and Instance Segmentation of LiDAR point clouds for autonomous driving
http://semantic-kitti.org
MIT License

Stop-gradients in skip-connections #99

Closed kazuto1011 closed 1 year ago

kazuto1011 commented 2 years ago

Thank you for sharing your code. I noticed that all backbones call detach() in their skip connections. For example:

https://github.com/PRBonn/lidar-bonnetal/blob/5a5f4b180117b08879ec97a3a05a3838bce6bb0f/train/backbones/squeezesegV2.py#L156
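For context, by "calling detach() in skip connections" I mean that the tensor stored for the decoder is cut off from the autograd graph. Below is a minimal sketch of that pattern (my own simplification, loosely mirroring the linked helper, not the actual repository code):

import torch
import torch.nn as nn

# minimal sketch of a stop-gradient skip (my simplification, not the repo code):
# the input of a down-sampling layer is stored for the decoder with detach(),
# so no gradient can flow back through that skip
down = nn.Conv2d(8, 8, 3, stride=2, padding=1)

def run_layer(x, layer, skips, os):
    y = layer(x)
    if y.shape[2] < x.shape[2]:     # the layer down-sampled the input
        skips[os] = x.detach()      # skip tensor is cut off from the autograd graph
        os *= 2
    return y, skips, os

x = torch.randn(1, 8, 64, 64, requires_grad=True)
y, skips, os = run_layer(x, down, {}, 1)
print(x.requires_grad, skips[1].requires_grad)   # True False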

Could you tell me where this idea comes from? I cannot find a corresponding part in the official SqueezeSeg/SqueezeSegV2 implementations. I'm also concerned that the first detach() in SqueezeSegV2 does not do what was intended:

https://github.com/PRBonn/lidar-bonnetal/blob/5a5f4b180117b08879ec97a3a05a3838bce6bb0f/train/backbones/squeezesegV2.py#L170-L174

skip_in is detached and never referenced afterwards, so the self.conv1b layer never receives gradients and its parameters are never updated. Here is a quick check I did:

# Backbone and Decoder taken from squeezesegV2.py;
# encoder_params / decoder_params are built from the usual arch config (omitted here)
import torch

encoder = Backbone(encoder_params)
decoder = Decoder(decoder_params, None)

# one dummy range image: batch 1, 5 input channels, 64 x 512
x = torch.randn(1, 5, 64, 512)
y = decoder(*encoder(x))
y.sum().backward()

# any parameter whose .grad is still None never received a gradient
for name, p in encoder.named_parameters():
    if p.grad is None:
        print(name, "is None")

The above snippet gives:

conv1b.0.weight is None
conv1b.0.bias is None
conv1b.1.weight is None
conv1b.1.bias is None
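
To show the effect in isolation, here is a self-contained toy example (not the repository code) that mirrors the situation of self.conv1b: one layer feeds the main (attached) path, while the other layer's only output is consumed through a detached skip.

import torch
import torch.nn as nn

# toy example, not the repo code: conv_skip's only output goes through detach(),
# like self.conv1b above, while conv_main feeds the main (attached) path
conv_main = nn.Conv2d(4, 4, 3, padding=1)
conv_skip = nn.Conv2d(4, 4, 3, padding=1)

x = torch.randn(1, 4, 8, 8)
main = conv_main(x)
skip = conv_skip(x).detach()       # stop-gradient: this branch is cut from the graph

loss = (main + skip).sum()
loss.backward()

print(conv_main.weight.grad is None)   # False: main path receives gradients
print(conv_skip.weight.grad is None)   # True: the detached branch receives none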
jbehley commented 1 year ago

oops, this issue is quite old. sorry.

As far as I remember, detaching or not does not make a big difference in performance. However, Andres mentioned he had better experiences when the whole network is trained (rather than gradients just skipping the middle part through the skips), which is why he added the detach. But I'm not sure if there is a better explanation.

kazuto1011 commented 1 year ago

Thank you for your response. Okay, I understand that detach() was adopted empirically. I had assumed that stopping the gradient flow in the skip connections might cause inefficiency or instability.