OpenGVLab / DiffRate

[ICCV 23] An approach to improve the efficiency of Vision Transformers (ViT) by applying token pruning and token merging jointly, with a differentiable compression rate.

How does the gradient flow into the FLOP loss? #1

Closed: kaikai23 closed this issue 1 month ago

kaikai23 commented 1 year ago

Congratulations on the great work! It inspires me very much.

I think I understand how the gradient is passed through the classification loss in the paper, thanks to the re-parameterization trick and the straight-through (ST) estimator, but I don't fully understand how the FLOPs loss (i.e., $L_F$) passes gradients to $\alpha_p$ and $\alpha_m$. Did you use any tool to make the computation of FLOPs differentiable with respect to the pruning rate?
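For context, my understanding of the ST part is something like the following generic sketch (my own toy example, not the paper's code):

```python
import torch
import torch.nn.functional as F

# Toy straight-through (ST) estimator: the forward pass uses the hard,
# discrete selection, while the backward pass routes gradients through
# the soft probabilities as if no discretization had happened.
logits = torch.zeros(4, requires_grad=True)
probs = torch.softmax(logits, dim=0)
hard = F.one_hot(probs.argmax(), num_classes=probs.numel()).float()
selection = hard + probs - probs.detach()  # forward == hard, grad flows via probs

# Gradients still reach the logits despite the argmax:
(selection * torch.tensor([1.0, 0.9, 0.8, 0.7])).sum().backward()
print(logits.grad)
```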

ChenMnZ commented 1 year ago

Hi, thanks for your interest!

To begin, we compute $\alpha_p$ and $\alpha_m$ using Eq. (7). Subsequently, the FLOPs are determined as a function of these values, so the FLOPs loss is differentiable with respect to them by construction. For a detailed explanation, please refer to Sec. B.1 in the Appendix; I believe your confusion can be resolved by consulting it. The corresponding code can be found at this link.
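In other words, no special tool is needed: once the (soft) token count of each block is a differentiable function of the compression rates, the FLOPs are an analytic expression and standard autograd handles the rest. Here is a minimal sketch of that idea for a single block, using a simplified per-block FLOPs formula and a hypothetical squared-error budget loss (not our exact implementation; see Sec. B.1 and the linked code for the real one):

```python
import torch

def block_flops(n, dim=768, mlp_ratio=4.0):
    # Rough per-block ViT FLOPs as a function of token count n:
    # QKV + output projections: 4*n*dim^2; attention maps: 2*n^2*dim;
    # two MLP linear layers: 2*mlp_ratio*n*dim^2.
    return 4 * n * dim**2 + 2 * n**2 * dim + 2 * mlp_ratio * n * dim**2

# Hypothetical learnable logits over candidate keep-rates for one block.
candidates = torch.tensor([1.0, 0.9, 0.8, 0.7])
logits = torch.zeros(4, requires_grad=True)
probs = torch.softmax(logits, dim=0)
keep_rate = (probs * candidates).sum()       # soft, differentiable keep-rate

n0 = 196.0                                   # input token count
flops = block_flops(n0 * keep_rate)          # differentiable in `logits`
target = 0.8 * block_flops(torch.tensor(n0)) # toy FLOPs budget
loss_f = (flops / target - 1.0) ** 2         # toy FLOPs loss
loss_f.backward()
print(logits.grad)                           # gradients reach the rate logits
```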