ilia10000 / dataset-distillation

Soft-Label Dataset Distillation and Text Dataset Distillation

A weird thing in the backward function #7

Closed MezereonXP closed 1 year ago

MezereonXP commented 1 year ago

Hello, I have cloned this repo and am trying to understand the code.

However, I have found some weird things in the Trainer class in train_distilled_image.py.

The Trainer class has a method named backward:

def backward(self, model, rdata, rlabel, steps, saved_for_backward):
    l, params, gws = saved_for_backward
    # ....

The params and gws come from the forward function, but they have different lengths!

I inserted print statements into the forward function like this:

def forward(self, model, rdata, rlabel, steps):
    # .... code ....
    print(f"params's length is {len(params)}")
    print(f"gws's length is {len(gws)}")
    return ll, (ll, params, gws)

and ran this command:

python main.py --mode distill_basic --dataset Cifar10 --arch AlexCifarNet  --distill_lr 0.001 --train_nets_type known_init --n_nets 1  --test_nets_type same_as_train

You will see the log:

params's length is 31
gws's length is 30
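
A minimal sketch with placeholder numbers (not the repo's code) of how such an off-by-one arises when the initial weights are stored alongside one post-GD weight and one gradient per step:

```python
# Placeholder sketch: the initial weights plus one updated weight per step
# are collected, but only one gradient per step -- hence 31 vs. 30 above.
lrs = [0.1, 0.2, 0.3]      # stand-ins for the per-step (data, label, lr) tuples
w = 1.0                    # stand-in for the initial weights params[0]
params, gws = [w], []
for lr in lrs:
    gw = 2.0 * w           # stand-in for the gradient of the step loss
    gws.append(gw)
    w = w - lr * gw        # post-GD weights of this step
    params.append(w)
print(len(params), len(gws))   # 4 3 -> len(params) == len(gws) + 1
```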

In the backward function, there is a zip call:

for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):

zip(steps, params, gws) truncates to the shortest input, so it ignores the final element of params.
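
A tiny self-contained demonstration of that behaviour (placeholder values only):

```python
# zip() stops at the shortest iterable, so the trailing element of params
# (the final weights) is simply never paired inside the loop.
steps  = ["step0", "step1"]            # placeholder values
params = ["w0", "w1", "w2"]            # one extra trailing element
gws    = ["gw0", "gw1"]
print(list(zip(steps, params, gws)))
# [('step0', 'w0', 'gw0'), ('step1', 'w1', 'gw1')]  -> 'w2' is dropped
```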

Question 1: Is that a mistake? Does ignoring the final element of params affect the training?

In the backward function:

for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
    # hvp_in are the tensors we need gradients w.r.t. final L:
    #   lr (if learning)
    #   data
    #   ws (PRE-GD) (needed for next step)
    #
    # source of gradients can be from:
    #   gw, the gradient in this step, whose gradients come from:
    #     the POST-GD updated ws
    hvp_in = [w]
    if not state.freeze_data:
        hvp_in.append(data)
    hvp_in.append(lr)
    if not state.static_labels:
        hvp_in.append(label)
    # dw (computed before this loop) holds the gradient of the final loss L
    # w.r.t. the POST-GD weights of this step
    dgw = dw.neg()  # gw is already weighted by lr, so simple negation
    hvp_grad = torch.autograd.grad(
        outputs=(gw,),
        inputs=hvp_in,
        grad_outputs=(dgw,),
        retain_graph=True
    )

In the first iteration of this loop, w is params[-2] and hvp_grad contains the gradient of gw with respect to params[-2]. However, the first dgw is the gradient of the loss with respect to params[-1].

I cannot fully understand the meaning of hvp_grad. (Is it something like Newton's method?)

Question 2: The logic of hvp_grad is hard to understand. Could you please explain those gradients in detail?

MezereonXP commented 1 year ago

I have understood the logic! I will close this issue and present the details of that part below.

Notation

The synthetic data: $s_1, s_2, ..., s_n$

Real data: $(x, y)$

Assume the initial parameters are $w_0$.

A gradient descent step is taken after each synthetic sample, where $L_t$ denotes the loss on $s_t$ evaluated at $w_{t-1}$:

$$w_t = w_{t-1} - \eta_{t-1} \nabla_{w_{t-1}} L_t$$
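
Unrolling this recursion (pure algebra from the update rule above) makes explicit how each synthetic sample enters the final weights:

$$w_n = w_0 - \sum_{t=1}^{n} \eta_{t-1}\, \nabla_{w_{t-1}} L_t$$

so $w_n$ depends on $s_n$ only through the last term of the sum.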

For the real data $(x, y)$, the corresponding loss under the final weights $w_n$ is:

$$L = L(f(x; w_n), y)$$

We can use this loss $L$ to update the synthetic data.

For $s_n$, we need the following gradient to update it:

$$\frac{\partial L}{\partial s_n} = \frac{\partial L}{\partial w_n}\cdot\frac{\partial w_n}{\partial s_n}$$

Since

$$w_n = w_{n-1} - \eta_{n-1} \nabla_{w_{n-1}} L_n$$

we have:

$$\frac{\partial w_n}{\partial s_n} = \frac{\partial}{\partial s_n}\left(-\eta_{n-1} \nabla_{w_{n-1}} L_n\right)$$

(note that $w_{n-1}$ itself does not depend on $s_n$, so only this term contributes).
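
Substituting this into the chain rule above gives, as a worked step using only the formulas already written:

$$\frac{\partial L}{\partial s_n} = \frac{\partial L}{\partial w_n}\cdot\frac{\partial w_n}{\partial s_n} = -\,\eta_{n-1}\,\frac{\partial L}{\partial w_n}\cdot\frac{\partial}{\partial s_n}\nabla_{w_{n-1}} L_n$$

i.e. a vector-Jacobian product of the per-step gradient with the incoming gradient $\frac{\partial L}{\partial w_n}$. This appears to be exactly what the torch.autograd.grad(outputs=(gw,), inputs=hvp_in, grad_outputs=(dgw,)) call evaluates, with the sign and the $\eta_{n-1}$ factor handled by dgw = dw.neg() and the lr-weighted gw, per the code comments quoted above.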

The entry of hvp_grad corresponding to the data input is therefore the gradient $\frac{\partial L}{\partial s_n}$ (when we consider only the synthetic data, without learnable labels and learning rates).
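
To connect the formula to the code, here is a minimal, self-contained PyTorch sketch (toy tensors and a toy loss, not the repo's code) of the torch.autograd.grad pattern used in backward: it contracts the Jacobian of gw with the vector dgw, so the entry for the data input is the kind of gradient derived above, computed without ever forming a full second-derivative matrix.

```python
import torch

# Toy stand-ins: w for the pre-GD weights of one step, s for one synthetic sample.
w = torch.randn(3, requires_grad=True)
s = torch.randn(3, requires_grad=True)

# Toy per-step loss; create_graph=True keeps the graph so we can differentiate gw again.
inner_loss = ((w * s) ** 2).sum()
gw, = torch.autograd.grad(inner_loss, w, create_graph=True)   # plays the role of gws[i]

dgw = torch.randn(3)   # stands in for the (negated) incoming gradient dL/dw of the next step

# Vector-Jacobian product: differentiate gw w.r.t. (w, s) and contract with dgw.
hvp_grad = torch.autograd.grad(
    outputs=(gw,), inputs=(w, s), grad_outputs=(dgw,), retain_graph=True
)
print(hvp_grad[0])   # (d gw / d w)^T dgw  -> a Hessian-vector product w.r.t. the weights
print(hvp_grad[1])   # (d gw / d s)^T dgw  -> the gradient that flows to the synthetic data
```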