xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

Understanding loss landscape #40

Closed SonHyegang closed 7 months ago

SonHyegang commented 7 months ago

I understand that the z-axis in the loss landscape visualization is the NLL, but I'm curious what the x-axis and y-axis mean. I can see in loss_landscapes.py how the x and y values enter the calculation, but I don't have an intuitive understanding of them.

    # Grid of (x, y) coefficients; ratio_grid has shape (n_y, n_x, 2)
    xs = np.linspace(x_min, x_max, n_x)
    ys = np.linspace(y_min, y_max, n_y)
    ratio_grid = np.stack(np.meshgrid(xs, ys), axis=0).transpose((1, 2, 0))
    print(ratio_grid)
    metrics_grid = {}
    for ratio in ratio_grid.reshape([-1, 2]):
        print(ratio)
        ws = copy.deepcopy(ws0)
        # Scale each of the two basis directions by its coefficient (x or y)
        gs = [{k: r * bs[k] for k in bs} for r, bs in zip(ratio, bases)]
        # Perturbed weights: ws0 + x * bases[0] + y * bases[1], per parameter k
        gs = {k: torch.sum(torch.stack([g[k] for g in gs]), dim=0) + ws[k] for k in gs[0]}
        print(gs)
        model.load_state_dict(gs)

        # Evaluate the perturbed model at this grid point
        print("Grid: ", ratio, end=", ")
        *metrics, cal_diag = tests.test(model, n_ff, dataset, transform=transform,
                                        cutoffs=cutoffs, bins=bins, verbose=verbose, period=period, gpu=gpu)
        l1, l2 = norm.l1(model, gpu).item(), norm.l2(model, gpu).item()
        metrics_grid[tuple(ratio)] = (l1, l2, *metrics)

    return metrics_grid

Thank you sincerely.

xxxnell commented 7 months ago

Hi @SonHyegang, thank you for reaching out with a great question. Sorry for the late reply; it has been a hectic week for me.

Essentially, we use two randomly initialized neural networks as the bases, that is, the x- and y-axes, of the loss landscape visualizations. Intuitively, the boundaries of these visualizations correspond to random initializations of the neural network. Likewise, the training trajectory of the neural network can be thought of as a path that moves from these boundaries to the center of the visualization. I hope this helps.
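To make the role of the two axes concrete, here is a minimal sketch of the same idea on a toy problem. Each grid point (x, y) is mapped to the weights w0 + x * b1 + y * b2, and the loss is evaluated there. Everything in it is an assumption made for illustration: plain Python lists stand in for tensors, toy_loss stands in for the NLL, random_direction and perturb are hypothetical helpers, and the directions are plain Gaussian draws, whereas the repository may construct and normalize its bases differently.

```python
import random


def random_direction(weights, rng):
    """Draw a Gaussian direction with the same keys and sizes as `weights` (hypothetical helper)."""
    return {k: [rng.gauss(0.0, 1.0) for _ in v] for k, v in weights.items()}


def perturb(w0, bases, ratio):
    """Return w0 + x * b1 + y * b2 for ratio = (x, y), parameter by parameter."""
    return {
        k: [w + sum(r * b[k][i] for r, b in zip(ratio, bases)) for i, w in enumerate(w0[k])]
        for k in w0
    }


def toy_loss(weights):
    """Stand-in for the NLL: squared distance from the origin (minimum at all-zero weights)."""
    return sum(v * v for vs in weights.values() for v in vs)


rng = random.Random(0)

# Pretend these are the trained weights; they sit at the minimum of toy_loss.
w0 = {"layer.weight": [0.0, 0.0, 0.0], "layer.bias": [0.0]}

# Two random directions play the role of the x- and y-axes.
bases = [random_direction(w0, rng), random_direction(w0, rng)]

# 5 x 5 grid of coefficients in [-1, 1]; (0, 0) is the center of the plot.
n = 5
xs = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
ys = list(xs)

# The z-axis value at each (x, y) is the loss of the perturbed weights.
landscape = {(x, y): toy_loss(perturb(w0, bases, (x, y))) for x in xs for y in ys}

center = landscape[(0.0, 0.0)]
corner = landscape[(1.0, 1.0)]
print(f"loss at center: {center:.3f}, loss at corner: {corner:.3f}")
```

In this sketch the center of the grid is the trained model itself, and moving along either axis adds more of one random direction to its weights, so the loss rises toward the boundaries.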