mhamilton723 / STEGO

Unsupervised Semantic Segmentation by Distilling Feature Correspondences
MIT License

Why are losses for linear probe and cluster probe combined for model training? #57

Closed codingbutstillalive closed 1 year ago

codingbutstillalive commented 1 year ago

Based on the paper, the linear probe loss should not be used for STEGO's model training. But as far as I can see, both the linear probe loss and the cluster loss contribute to the overall loss used in backpropagation. Please explain.

axkoenig commented 1 year ago

Good question, I was also a bit puzzled by this at first sight of the code, but I think here's what's going on:

  1. First of all, it's important to understand that they set up 3 different optimizers here: one for the "net" (the DINO backbone + segmentation head), one for the linear probe, and one for the cluster probe.
  2. In the train loop they first pass the batch through the net here and then compute a bunch of losses on that output.
  3. Then it's important to note that a few lines down they detach the code (i.e. the output of the segmentation head) from the computation graph that led through the "net". Hence any operation done on the tensor after the detachment can no longer influence the gradients of the "net".
  4. They also want to learn the weights of the linear probe layer and the cluster centers for the cluster probe. To do that they pass the detached code from step (3) into the linear probe and cluster probe modules and obtain a linear_loss and a cluster_loss. Each of these losses now has its own computation graph: linear_loss depends on the learnable parameters of the linear layer, and cluster_loss depends only on the cluster parameters.
  5. They then add these two losses to the already existing loss, which has its own, separate computation graph. The cool thing is that these three separate computation graphs stay intact after the + operation, and a grad_fn=<AddBackward0> node is added to the combined graph (more details here). E.g., if you print out the loss variable you'll see tensor(3.3079, grad_fn=<AddBackward0>).
  6. When manually calling self.manual_backward(loss) here, the autograd engine can trace back which parts of the overall system each gradient came from and calculate the backward pass for the entire combined graph. Because of the detach in step (3), the gradients of the linear and cluster losses therefore don't influence the updates of the weights of the segmentation head.
  7. The only thing left is the optimizer steps in the lines below, which use the optimizers from step (1).
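
The steps above can be sketched in a few lines of PyTorch. This is a minimal, hypothetical example (the modules and losses are stand-ins, not STEGO's actual code); the point it demonstrates is that detaching the code keeps the probe losses from reaching the net, even though a single combined loss is backpropagated:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the real components (step 1: three separate optimizers)
net = nn.Linear(8, 4)                               # ~ backbone + segmentation head
linear_probe = nn.Linear(4, 3)                      # ~ linear probe
cluster_centers = nn.Parameter(torch.randn(3, 4))   # ~ cluster probe parameters

net_opt = torch.optim.Adam(net.parameters(), lr=1e-3)
linear_opt = torch.optim.Adam(linear_probe.parameters(), lr=1e-3)
cluster_opt = torch.optim.Adam([cluster_centers], lr=1e-3)

x = torch.randn(16, 8)
labels = torch.randint(0, 3, (16,))

# Step 2: forward pass through the net, compute a loss on its output
code = net(x)
net_loss = code.pow(2).mean()        # stand-in for the correspondence loss

# Step 3: detach cuts the graph back to the net
detached = code.detach()             # detached.requires_grad is now False

# Step 4: probe losses built on the detached code have their own graphs
linear_loss = nn.functional.cross_entropy(linear_probe(detached), labels)
cluster_loss = -(detached @ cluster_centers.t()).max(dim=1).values.mean()

# Step 5: summing keeps all three graphs intact (grad_fn=<AddBackward0>)
loss = net_loss + linear_loss + cluster_loss

# Step 6: one backward pass fills gradients for all three parameter groups,
# but net's gradients come only from net_loss because of the detach
loss.backward()

# Step 7: each optimizer updates only its own parameters
net_opt.step(); linear_opt.step(); cluster_opt.step()
```

Printing `loss` here shows the `grad_fn=<AddBackward0>` node mentioned in step (5), and `detached.requires_grad` is `False`, which is exactly what blocks the probe gradients from flowing into the net.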
axkoenig commented 1 year ago

Let me know if this helps. I'd be interested in hearing your thoughts on this.

codingbutstillalive commented 1 year ago

Thank you very much, both for the effort of explaining it in such detail and for the additional insights. I really appreciate it. This explanation resolves my confusion. It's a nice piece of engineering, although a few more code comments would have helped before releasing the code, I suppose ;)

mhamilton723 commented 1 year ago

@axkoenig Thank you for that great description of the technique! Couldn't have said it better myself, and appreciate your and @codingbutstillalive's patience :).