justincui03 / tesla

MIT License

The partial derivative of loss with respect to x in equation (5)(6) in paper #4

Open j-cyoung opened 2 months ago

j-cyoung commented 2 months ago

Thanks for your great work! The method is impressive and the results are amazing. I have a question about the derivation of the partial derivative of the loss with respect to x in equations (5) and (6) of the paper.


As described in the paper "Dataset Distillation" by Tongzhou Wang, $\theta$ is a function of $x$, so the derivative of the loss with respect to $x$ can be written as $$\frac{d\, l(\theta(x), x)}{d x}=\frac{\partial l}{\partial \theta}\frac{\partial \theta}{\partial x}+\frac{\partial l}{\partial x}\frac{\partial x}{\partial x}=\frac{\partial l}{\partial \theta}\frac{\partial \theta}{\partial x}+\frac{\partial l}{\partial x},$$ where the second term $\frac{\partial l}{\partial x}$ is the partial derivative with respect to the explicit argument $x$ only, with $\theta$ held fixed.

However, I think the gradient with respect to $x$ shown in equation (6) doesn't take $\theta(x)$ into account and treats it as a constant. Maybe I am missing something or my derivation is wrong, so could you help me with this question? Thanks a lot.
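The distinction raised in this question can be checked numerically on a toy problem. The sketch below is not taken from the paper or the repository: the scalar loss `0.5 * (theta * x - y) ** 2`, the single inner gradient step, and all names and values (`theta0`, `alpha`, `y`) are assumptions chosen only to illustrate the chain rule. It compares the full derivative, which includes the path through $\theta(x)$, against the derivative obtained when $\theta$ is treated as a constant.

```python
# Toy check of d l(theta(x), x) / dx = dl/dtheta * dtheta/dx + dl/dx.
# Everything here (loss, step size, values) is a hypothetical illustration.

def inner_step(theta0, x, y, alpha):
    """One gradient step on 0.5 * (theta0 * x - y)**2, so theta depends on x."""
    return theta0 - alpha * (theta0 * x - y) * x

def loss(theta, x, y):
    """Outer loss; depends on x both explicitly and through theta."""
    return 0.5 * (theta * x - y) ** 2

theta0, x, y, alpha = 0.5, 2.0, 3.0, 0.1
theta = inner_step(theta0, x, y, alpha)

dl_dtheta = (theta * x - y) * x             # partial of loss w.r.t. theta
dtheta_dx = -alpha * (2 * theta0 * x - y)   # derivative of inner_step w.r.t. x
dl_dx_explicit = (theta * x - y) * theta    # partial w.r.t. x, theta held fixed

total = dl_dtheta * dtheta_dx + dl_dx_explicit

# Central finite difference of the composed function x -> loss(theta(x), x).
h = 1e-6
fd = (loss(inner_step(theta0, x + h, y, alpha), x + h, y)
      - loss(inner_step(theta0, x - h, y, alpha), x - h, y)) / (2 * h)

print("total derivative      :", total)
print("finite difference     :", fd)
print("theta held constant   :", dl_dx_explicit)
```

In this toy setting the finite difference agrees with the full chain-rule expression and differs from the value obtained with $\theta$ held constant, which is the gap the question is asking about.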

justincui03 commented 2 months ago

Hi, Chenyang,

Thank you very much for your interest in our work. Which equation in Tongzhou's paper are you referring to? I am not sure I understand your question correctly, but I don't think we treat $\theta(x)$ as a constant: MTT's loss depends on the gradients computed using $X$, which is similar to DSA. The writing may be confusing because DSA (algorithm 1) uses y in the loss, whereas MTT (algorithm) uses $\hat{\theta}$.

j-cyoung commented 2 months ago

Thanks for your reply. Section 3.1 of Tongzhou's paper says: "We derive the new weights $\theta_1$ as a function of distilled data $\tilde{x}$", where $\tilde{x}$ is the distilled image.

Tongzhou's paper then introduces more steps on the synthetic images: $\theta_1$ is obtained from $\tilde x_0$, then $\theta_2$ is obtained from $\tilde x_1$ based on $\theta_1$, and so on, assuming each batch has size 1. In this case, I think $\theta_2, \theta_3, \cdots$ are also functions of $\tilde{x}_0$.
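To make the dependence explicit, here is a short worked version of the first two steps. It assumes plain gradient descent with a step size written as $\eta$ and a per-step loss $\ell(\theta; \tilde{x})$; these symbols are chosen for illustration and may not match either paper's notation exactly.

$$\theta_1 = \theta_0 - \eta \nabla_\theta \ell(\theta_0; \tilde{x}_0), \qquad \theta_2 = \theta_1 - \eta \nabla_\theta \ell(\theta_1; \tilde{x}_1)$$

$$\frac{\partial \theta_1}{\partial \tilde{x}_0} = -\eta \frac{\partial}{\partial \tilde{x}_0}\nabla_\theta \ell(\theta_0; \tilde{x}_0), \qquad \frac{\partial \theta_2}{\partial \tilde{x}_0} = \left(I - \eta \nabla_\theta^2 \ell(\theta_1; \tilde{x}_1)\right)\frac{\partial \theta_1}{\partial \tilde{x}_0}$$

Under these assumptions $\frac{\partial \theta_2}{\partial \tilde{x}_0}$ is generally nonzero, and the same recursion carries the dependence on $\tilde{x}_0$ forward to $\theta_3, \theta_4, \cdots$.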


However, in the derivation of formula (6), the paper seems to simplify

$$\frac{\partial}{\partial \tilde{X}_i}\sum_{j=0}^{T-1} \nabla_\theta \ell\left(\hat{\theta}_{t+j} ; \tilde{X}_j\right)$$

to

$$\frac{\partial}{\partial \tilde{X}_i}\nabla_\theta \ell\left(\hat{\theta}_{t+i} ; \tilde{X}_i\right),$$

whereas I think the exact form should be

$$\frac{\partial}{\partial \tilde{X}_i}\sum_{j=i}^{T-1} \nabla_\theta \ell\left(\hat{\theta}_{t+j} ; \tilde{X}_j\right),$$

since the $j=i$ term depends on $\tilde{X}_i$ directly and every $\hat{\theta}_{t+j}$ with $j>i$ is a function of $\tilde{X}_i$. To put it another way, $\tilde{X}_i$ should be a leaf node in the computation graphs of all the $\nabla_\theta \ell\left(\hat{\theta}_{t+j} ; \tilde{X}_j\right)$ with $j\geq i$. So I still don't understand why the other $\nabla_\theta \ell\left(\hat{\theta}_{t+j} ; \tilde{X}_j\right)$ terms are dropped in equation (6).
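The argument above can be illustrated with a scalar toy trajectory. This is a sketch under stated assumptions, not the paper's method or the repository's code: the per-step loss `0.5 * (theta - x) ** 2`, the step size `ALPHA`, and the helper names are all hypothetical. With this loss, $\nabla_\theta \ell(\theta; x) = \theta - x$, so the derivative of the summed gradients with respect to one batch $x_i$ has a closed form that can be compared with a finite difference.

```python
# Toy unrolled trajectory: theta_{j+1} = theta_j - ALPHA * (theta_j - x_j).
# All names and values are illustrative assumptions.

ALPHA = 0.1

def unroll(theta0, xs, alpha=ALPHA):
    """Return [theta_0, ..., theta_T] for per-step loss 0.5 * (theta - x)**2."""
    thetas = [theta0]
    for x in xs:
        grad = thetas[-1] - x               # gradient of the per-step loss
        thetas.append(thetas[-1] - alpha * grad)
    return thetas

def summed_gradient(theta0, xs, alpha=ALPHA):
    """Sum over j of grad_theta loss(theta_j; x_j) along the trajectory."""
    thetas = unroll(theta0, xs, alpha)
    return sum(thetas[j] - xs[j] for j in range(len(xs)))

def full_derivative(i, T, alpha=ALPHA):
    """Closed-form d(summed_gradient)/d(x_i), including later steps j > i."""
    explicit = -1.0                          # j = i term: d(theta_i - x_i)/dx_i
    # For j > i, d(theta_j)/d(x_i) = alpha * (1 - alpha)**(j - i - 1).
    through_theta = sum(alpha * (1 - alpha) ** (j - i - 1)
                        for j in range(i + 1, T))
    return explicit + through_theta

def finite_difference(theta0, xs, i, h=1e-6):
    """Central finite difference of summed_gradient w.r.t. xs[i]."""
    up, down = list(xs), list(xs)
    up[i] += h
    down[i] -= h
    return (summed_gradient(theta0, up) - summed_gradient(theta0, down)) / (2 * h)

xs = [1.0, 2.0, 3.0, 4.0]
i = 1
exact = full_derivative(i, len(xs))
numeric = finite_difference(0.0, xs, i)
single_term = -1.0                           # keeps only the j = i term

print("sum over j >= i :", exact)
print("finite diff     :", numeric)
print("j = i term only :", single_term)
```

In this toy case the finite difference matches the sum over $j \geq i$ and differs from the single $j=i$ term, except for the last batch, where no later step exists. It only illustrates the dependency structure being asked about; it does not show how equation (6) in the paper is meant to be read.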


Sorry for the late reply. I'm new to dataset distillation and have been working through the papers and code of DD, DC, MTT, etc. I then checked formula (6) again and still can't understand it.