labmlai / annotated_deep_learning_paper_implementations

🧑‍🏫 60+ Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit, ...), optimizers (adam, adabelief, sophia, ...), gans(cyclegan, stylegan2, ...), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, ... 🧠
https://nn.labml.ai
MIT License

Question about gatv2 code #228

Open XiaokangORCA opened 1 year ago

XiaokangORCA commented 1 year ago

Hello, I am a beginner with GAT, and I've been studying your GATv2 code lately. I have a question about the code in

`labml_nn/graphs/gatv2/__init__.py`

When calculating `g_sum`:

```python
g_sum = g_l_repeat + g_r_repeat_interleave
```

The comment says: "Now we add the two tensors to get"

$$ \lbrace \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_N}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_N}, \dots \rbrace $$

But according to the code before it, `g_l_repeat` is

$$ \lbrace \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots \rbrace $$

and `g_r_repeat_interleave` is

$$ \lbrace \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots \rbrace $$

So I think the result of adding the two tensors should be

$$ \lbrace \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_2}, \dots \rbrace $$

I'm not sure whether I have overlooked something important or whether there's a mismatch between your comments and the code. I would greatly appreciate it if you could help clear up my confusion. Thank you.
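One quick way to check the order is to replay both operations on symbolic labels instead of real embeddings. This is a minimal sketch in plain Python; `repeat_rows` and `repeat_interleave_rows` are hypothetical helpers that only mimic `Tensor.repeat(n, 1)` and `Tensor.repeat_interleave(n, dim=0)` along the first dimension, and the "sum" is written as a string so each pairing stays visible.

```python
# Replay the two operations on symbolic labels to see which index varies fastest.
# repeat_rows / repeat_interleave_rows are hypothetical stand-ins for
# Tensor.repeat(n, 1) and Tensor.repeat_interleave(n, dim=0) on the first dimension.

def repeat_rows(items, n):
    """Tile the whole sequence n times: [a, b, c] -> [a, b, c, a, b, c, ...]."""
    return list(items) * n


def repeat_interleave_rows(items, n):
    """Repeat each element n times in place: [a, b, c] -> [a, a, a, b, b, b, ...]."""
    return [x for x in items for _ in range(n)]


n_nodes = 3
g_l = [f"l{i}" for i in range(1, n_nodes + 1)]
g_r = [f"r{i}" for i in range(1, n_nodes + 1)]

g_l_repeat = repeat_rows(g_l, n_nodes)
g_r_repeat_interleave = repeat_interleave_rows(g_r, n_nodes)

# Element-wise "addition", kept as strings so the pairing is readable.
g_sum = [f"{a}+{b}" for a, b in zip(g_l_repeat, g_r_repeat_interleave)]
print(g_sum)
# ['l1+r1', 'l2+r1', 'l3+r1', 'l1+r2', 'l2+r2', 'l3+r2', 'l1+r3', 'l2+r3', 'l3+r3']
```

The left index cycles fastest while the right index stays fixed within each block of N entries.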

rjavierch commented 6 months ago

Hello! I am also new to GAT and came across your issue.

To answer your question: I think the implementation on the website is correct, at least partially. Here is why.

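`g_l_repeat`, written in the comment as

$$ \lbrace \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots \rbrace $$

is produced by `repeat`, which tiles the whole tensor `n_nodes` times: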

```python
>>> import torch
>>> n_nodes = 3  # N
>>> tensor = torch.tensor([[1], [2], [3]])
>>> tensor.repeat(n_nodes, 1)
tensor([[1],
        [2],
        [3],
        [1],
        [2],
        [3],
        [1],
        [2],
        [3]])
```

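and `g_r_repeat_interleave`, written as

$$ \lbrace \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots \rbrace $$

is instead produced by `repeat_interleave`, which repeats each row `n_nodes` times in place: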

```python
>>> tensor.repeat_interleave(n_nodes, dim=0)
tensor([[1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [3],
        [3],
        [3]])
```

So the operation `g_l_repeat + g_r_repeat_interleave`, i.e.

$$ \lbrace \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots \rbrace $$

$$+$$

$$ \lbrace \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots \rbrace $$

gives, element-wise:

```python
>>> # Row by row this is 1+1, 2+1, 3+1, 1+2, 2+2, 3+2, 1+3, 2+3, 3+3
>>> tensor.repeat(n_nodes, 1) + tensor.repeat_interleave(n_nodes, dim=0)
tensor([[2],
        [3],
        [4],
        [3],
        [4],
        [5],
        [4],
        [5],
        [6]])
```
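Because both operands in that experiment are the same tensor `[1, 2, 3]`, the numeric sums hide which element came from which side. A plain-Python variant with distinct, purely illustrative values (`g_l = [1, 2, 3]`, `g_r = [10, 20, 30]`) keeps the pairs readable: the tens digit identifies the g_r term and the units digit the g_l term. Splitting the flat result into rows of N mirrors what a row-major reshape such as `.view(n_nodes, n_nodes)` does to a contiguous tensor.

```python
# Distinct, illustrative values: tens digit = g_r term, units digit = g_l term.
n_nodes = 3
g_l = [1, 2, 3]
g_r = [10, 20, 30]

g_l_repeat = g_l * n_nodes                                        # like repeat(n_nodes, 1)
g_r_repeat_interleave = [x for x in g_r for _ in range(n_nodes)]  # like repeat_interleave(n_nodes, dim=0)

flat = [a + b for a, b in zip(g_l_repeat, g_r_repeat_interleave)]
print(flat)  # [11, 12, 13, 21, 22, 23, 31, 32, 33]

# Row-major split into an n_nodes x n_nodes grid (same layout as .view on a contiguous tensor).
grid = [flat[row * n_nodes:(row + 1) * n_nodes] for row in range(n_nodes)]
for row in grid:
    print(row)
# [11, 12, 13]
# [21, 22, 23]
# [31, 32, 33]
```

So `grid[a][b]` equals `g_r[a] + g_l[b]`: the first index picks the g_r term and the second picks the g_l term.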

So the comment lists the right set of pairwise sums, only in a different order:

$$ \lbrace \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_N}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_N}, \dots \rbrace $$

To match the order the code actually produces (and avoid confusion), I think the comment should read as follows. The code itself is still correct:

$$ \lbrace \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_2}, \dots \rbrace $$
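For comparison, here is a sketch (not the repository's code) of which layout would yield the comment's order: interleave g_l and tile g_r, the reverse of what the code does. `tile` and `interleave` are hypothetical plain-Python stand-ins for `Tensor.repeat(n, 1)` and `Tensor.repeat_interleave(n, dim=0)`, applied to symbolic labels.

```python
n_nodes = 3
g_l = [f"l{i}" for i in range(1, n_nodes + 1)]
g_r = [f"r{i}" for i in range(1, n_nodes + 1)]


def tile(items, n):
    # Stand-in for Tensor.repeat(n, 1): [a, b, c] -> [a, b, c, a, b, c, ...]
    return list(items) * n


def interleave(items, n):
    # Stand-in for Tensor.repeat_interleave(n, dim=0): [a, b, c] -> [a, a, a, b, ...]
    return [x for x in items for _ in range(n)]


# Order produced by the code: g_l tiled, g_r interleaved (left index varies fastest).
code_order = [f"{a}+{b}" for a, b in zip(tile(g_l, n_nodes), interleave(g_r, n_nodes))]

# Order written in the comment: g_l interleaved, g_r tiled (right index varies fastest).
comment_order = [f"{a}+{b}" for a, b in zip(interleave(g_l, n_nodes), tile(g_r, n_nodes))]

print(code_order[:4])     # ['l1+r1', 'l2+r1', 'l3+r1', 'l1+r2']
print(comment_order[:4])  # ['l1+r1', 'l1+r2', 'l1+r3', 'l2+r1']

# Same N*N pairwise sums, different positions in the flat tensor.
print(sorted(code_order) == sorted(comment_order))  # True
```

Both orders cover the same N² pairwise sums; only the position of each sum in the flat tensor differs.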