zysszy / TreeGen

A Tree-Based Transformer Architecture for Code Generation. (AAAI'20)
MIT License

Need help to understand the code #5

iamxpy opened this issue 4 years ago (Open)

iamxpy commented 4 years ago

First, I would like to thank you for publishing the source code, which makes it much easier for me to understand the details of the model. I have read several key files of the project, including resolve_data.py, code_generate_model.py and run.py, and now I am somewhat stuck on the file predict.py. I have several questions and need some help.

(1) What does the symbol ^ mean? It appears both in the input file of the training procedure (train_trans.txt) and in the output files of inference (e.g. ATIS/out/0.txt), and the code in predict.py inserts or removes the symbol many times.

(2) How should I understand the content of the output files of predict.py (e.g. ATIS/out/0.txt)? According to the method BeamSearch() in predict.py, I think every 2 lines form one beam candidate, with the first line showing the generated tree and the second line its probability. But I failed to understand how the trees are organized into these sequences (with so many ^ symbols), and, what confused me even more, the "probabilities" are all negative numbers. I have attached a picture showing the content of 0.txt.

[image: content of 0.txt]

(3) Which part of this repository is not complete? According to the README, "This repository is not a complete one".

(4) The number of training epochs is set to pretrain_times = 100000; is it necessary to use such a large value to achieve the results in the paper? After running for a while, I found that the accuracy was 0.98+ after 100 epochs and 0.99+ after 500 epochs, so I just stopped training.

Thank you for your help in advance!

iamxpy commented 4 years ago

After debugging with several test samples, I think I understand the symbol ^ in ATIS/out/n.txt now. According to the method J_AddSon() in predict.py, the symbol ^ is used to match (close) the corresponding keyword; we can add some brackets to make this clearer (personally, I think it would be better to replace the ^s with brackets, which would make the output much easier to read and understand).

For example, the sequence

root Lambda variable $0 ^ ^ type e ^ ^ body Apply predicate airport ^ ^ arguments Variable variable $0 ^ ^ ^ End ^ ^ ^ ^ ^ ^

is equivalent to

(root (Lambda (variable ($0 ^) ^) (type (e ^) ^) (body (Apply (predicate (airport ^) ^) (arguments (Variable (variable ($0 ^) ^) ^) (End ^) ^) ^) ^) ^) ^)

and it corresponds to the following tree:

[image: the corresponding tree]
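To make the correspondence concrete, here is a minimal sketch (my own helper, not code from the repo) that parses such a ^-terminated preorder sequence into a nested (label, children) structure:

# Hypothetical helper (not from the repo): parse a ^-terminated preorder
# sequence into a nested (label, children) tree.
def parse_caret_sequence(tokens):
    pos = 0

    def parse_node():
        nonlocal pos
        label = tokens[pos]
        pos += 1
        children = []
        # Read children until the "^" that closes this node.
        while tokens[pos] != "^":
            children.append(parse_node())
        pos += 1  # consume the closing "^"
        return (label, children)

    return parse_node()

print(parse_caret_sequence("variable $0 ^ ^".split()))
# -> ('variable', [('$0', [])])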

zysszy commented 4 years ago

Thank you for your attention~

(1) What does the symbol ^ mean? Yeah, you got it! Your understanding is absolutely right~

(2) How to understand the content of the output files of predict.py (e.g. ATIS/out/0.txt)? On "how the trees are organized into these sequences": I think the answer is the same as for question 1. On "the 'probability' are all negative numbers": the probability here is the logarithm of the probability, which is why the values are negative.

(3) Which part of this repository is not complete? I am sorry about that. We have not provided the GEO and ATIS experiments from the paper.

(4) The number of training epochs is set to pretrain_times = 100000; is it necessary to use such a large value to achieve the results in the paper? It is not necessary to use such a large value. You can stop training when the accuracy is no longer improving.

I hope my answers can help you understand our project. If you have any questions, please feel free to tell me~

Zeyu

iamxpy commented 4 years ago

Thanks for your quick reply; I still have some questions.

(1) Confusion about the meaning of "Javaoutput.Probability".

Specifically, I get confused about the following expression in predict.py:

JavaOutNext.Probility = (JavaOut.Probility * math.pow(len(JavaOut.RuleList), apa) + math.log(max(1e-10, res[i]))) / math.pow(len(JavaOut.RuleList) + 1, apa)

Suppose we decode rule A at time step 1 and get the corresponding probability P(A) (the output of the softmax function), decode rule B at time step 2 with probability P(B), and so on.

I notice that, at time step 1, the expression above equals log(P(A)) because len(JavaOut.RuleList) is 0. At time step 2, the expression equals

(log(P(A)) * 1^apa + log(P(B))) / 2^apa

where len is 1. I hope I didn't make any silly mistake in the calculation.

It would seem more reasonable to me if the expression were changed to

JavaOutNext.Probility = JavaOut.Probility + math.log(max(1e-10, res[i])) 

which would mean Probility stands for log P(A) + log P(B) + ... = log[P(A) * P(B) * ...], i.e. the log of the joint probability of the rules.

Could you explain more about "the log of its percentage"?

(2) Why do we need the files Tree_Rule.in and Tree_Feature.out? Why can't we simply pass the content of these files (as a list or dict) to the methods that need them? And, most importantly, I failed to understand the content of Tree_Feature.out, namely the following snippet:

fw.write(out.replace(" node_gen ^", "") + "\n")
fw.write(node_par[0] + "\n")
fw.write(node_par[1] + "\n")
fw.write(node_par[2] + "\n")
fw.write(out.replace(" End ^", "") + "\n")
fw.write(out.replace(" End ^", "") + "\n")  # Is this redundant?
fw.write("1\n")  # Fixed value?
fw.write("1\n")  # Fixed value?
fw.write(out.replace(" End ^", "") + "\n")
fw.write("1\n")   # Is this useless?
fw.write(str(father_index_now) + "\n")   # Is this redundant?

The method getJavaOut() typically reads lines 4, 1, 2, 3, 0, 6, and 7 with the following code:

return Javaoutput(lines[4][:-1], Nl, lines[1][:-1], lines[2][:-1], lines[3][:-1], lines[0][:-1],lines[6][:-1], lines[7][:-1], "grow")

So is some of the content of Tree_Feature.out redundant? Also, those 1s are passed to the FatherTree and GrandFatherTree attributes of the Javaoutput object; I don't understand why they are just fixed to 1.

Hoping for your reply!

zysszy commented 4 years ago

(1) Confusion about the meaning of "Javaoutput.Probability". apa denotes alpha, which is a length penalty used in beam search. Details are in https://d2l.ai/chapter_recurrent-modern/beam-search.html (Section 9.8.4).
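A small sketch of the recurrence (the variable and function names are mine, not from the repo) may help: the update quoted above maintains score = (log P(rule_1) + ... + log P(rule_n)) / n^alpha, i.e. a length-penalized log joint probability, with apa playing the role of alpha.

import math

# Sketch with assumed names: incremental update of the length-normalized score
#   score = sum(log P(rule_i)) / n**alpha
# which matches the expression quoted from predict.py, with `apa` as alpha.
def update_score(old_score, old_len, step_prob, alpha):
    log_p = math.log(max(1e-10, step_prob))  # clamp to avoid log(0)
    return (old_score * old_len ** alpha + log_p) / (old_len + 1) ** alpha

alpha = 0.7  # illustrative value
s1 = update_score(0.0, 0, 0.5, alpha)   # == log(0.5), since 0**alpha == 0
s2 = update_score(s1, 1, 0.25, alpha)   # == (log(0.5) + log(0.25)) / 2**alpha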

(2) Why do we need the files Tree_Rule.in and Tree_Feature.out? Actually, we don't need these two files. This Python code was converted from Java, and I kept these two files to avoid introducing new bugs. Over the course of this research, a lot of unused code that is hard to remove has accumulated in this project.

Zeyu

iamxpy commented 4 years ago

@zysszy Many thanks for helping me out so much!

iamxpy commented 4 years ago

@zysszy Sorry to bother you again, I have another question. As you've mentioned in #3

we always expand the leftmost non-terminal node until all leaf nodes are terminals

I can confirm that this is true when doing inference in predict.py, but maybe I missed something, because I think it is not the case in the training process. For example, the whole AST of the first training sample (in train_trans.txt and train_tree.txt) looks like this:

[image: AST of the first training sample]

The numbers beside the non-terminal nodes indicate their order, which decides which tokens they can attend to under the control of the antimask, a lower-triangular matrix. In this example, after expanding the 8th node args (short for arguments), the model expands the 9th node args instead of the node Var (the leftmost non-terminal node after expanding the first args).

What am I doing wrong here? Thanks!
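As an aside, here is a minimal sketch of the kind of lower-triangular antimask described above (this is not the repo's code; whether TreeGen builds the mask exactly this way is an assumption):

import numpy as np

# Sketch: a lower-triangular "antimask" so that the node expanded at step i
# can only attend to nodes expanded at steps <= i.
def build_antimask(num_nodes):
    return np.tril(np.ones((num_nodes, num_nodes), dtype=np.float32))

print(build_antimask(4))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
# Positions where the mask is 0 are typically set to -inf before the softmax.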

zysszy commented 4 years ago

Sorry for the late reply.

The Node End, which denotes the end of the child nodes of Apply, is a child node of the node Apply. Thus, the 8th node args is still the leftmost non-terminal node.

Zeyu

iamxpy commented 4 years ago

I still feel confused about the Node End and the meaning of leftmost.

First of all, I would like to clarify that the data of the first example is as follows.

airport
root Lambda variable $0 $0^ variable^ type e e^ type^ body Apply predicate airport airport^ predicate^ arguments Variable variable node_gen node_gen^ variable^ Variable^ arguments^ Apply^ body^ Lambda^ root^
Unknown root Lambda variable variable^ Lambda^ Lambda type type^ Lambda^ Lambda body Apply predicate predicate^ Apply^ Apply arguments Variable variable variable^ Variable^ arguments^ Apply^ body^ Lambda^ root^ Unknown^
Unknown Unknown root Lambda Lambda^ root^ root Lambda Lambda^ root^ root Lambda body Apply Apply^ body^ body Apply arguments Variable Variable^ arguments^ Apply^ body^ Lambda^ root^ Unknown^ Unknown^
variable  Variable  arguments  Apply  body  Lambda  root
1 2 3 4 5 6 7 8 9 10 3
0 1 2 3 4 5 6 7 8 9 2 
9
-1 0 1 1 1 4 5 5 5 7 9 

Please correct me if I am wrong. There's a strong possibility that I have misunderstood something important.

(1) Question about the Node End

The Node End, which denotes the end of the child nodes of Apply, is a child node of the node Apply

There is no rule with Apply on the left-hand side and End on the right-hand side in the file ATIS/Rule.txt; however, the rule arguments true End does exist, and it is exactly the rule used in the first example (the 9 in the 6th line). Also, a list called J_NeedsEnd is used in the method J_isend(), and that list only contains arguments. So I think End is used to denote the end of the child nodes of arguments rather than of Apply, and it can be a child node of the node arguments, as shown in the picture I posted in my last comment.

(2) The meaning of leftmost.

2.1) When describing an AST as a tree, the leftmost node is the first non-terminal node (one that still needs expanding) that we meet in a DFS. So after expanding the 8th node args, the leftmost node should be its child node, the 10th node Var.

2.2) When denoting an AST as a sequence with the symbol ^ indicating the hierarchy, leftmost literally means the leftmost non-terminal node. See the method J_scan() in predict.py, which scans the sequence from left to right and expands the first non-terminal node. After expanding the 8th node args, the sequence would look like this: ... args Var ^ ^ args ^ ..., and Var is the leftmost non-terminal node, not the args node to its right.

Looking forward to your reply!

zysszy commented 4 years ago

I found that the figure is wrong: the 8th node args and the 9th node args are the same node. Thus, we first expand the 8th node with Rule 7 (arguments true Variable). Then, once all children of the 8th node have been generated, we expand the 8th node again with Rule 8 (arguments true End). So the next step is to expand the 10th node Var (the leftmost non-terminal node).

We expand a node until all of its children are generated. The children of the node args in this example need two steps to generate (Rules 7 and 8). For these kinds of nodes, we first finish all of these steps; after finishing them, we then select the next leftmost node to expand.

iamxpy commented 4 years ago

How should I understand the last line of the first example, -1 0 1 1 1 4 5 5 5 7 9? Based on the code, I think that after adding 1 to every number and prepending a -1, we get the parent-array representation of the AST, i.e. -1 0 1 2 2 2 5 6 6 6 8 10. We can see that the 6th node has 3 children, including the 8th node and the 9th node. I built the tree based on the parent array (the last line of every example) and the rules (the 6th line of every example); what am I doing wrong? Thanks!
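As a sanity check of this reading, here is a tiny sketch (my own helper, not repo code) that rebuilds the children lists from such a parent array:

# Hypothetical helper: rebuild children lists from a parent-array
# representation where parents[i] is the parent index of node i and the
# root has parent -1.
def children_from_parents(parents):
    children = {i: [] for i in range(len(parents))}
    for node, parent in enumerate(parents):
        if parent >= 0:
            children[parent].append(node)
    return children

parents = [-1, 0, 1, 2, 2, 2, 5, 6, 6, 6, 8, 10]
print(children_from_parents(parents)[6])  # -> [7, 8, 9]: node 6 has three children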

zysszy commented 4 years ago

Yeah, you are right if you remove the node start and use the node root as the 0th node~

iamxpy commented 4 years ago

I understand that the 6th line of every sample corresponds to the tree without start. But let's just use the tree with start so we don't have to shift (subtract 1 from) the indices. I still don't understand why we can conclude that the 8th node args and the 9th node args are the same node.

zysszy commented 4 years ago

Hmmm... Sorry, it is a bug. I think you are right. I set a wrong parent node (actually a grandparent node) for this rule. It is not an important bug; it will not hurt the effectiveness of TreeGen.

Thank you for this.

Zeyu

iamxpy commented 4 years ago

@zysszy Thank you soooo much for your valuable time!

zhaodezhu111 commented 1 year ago

@zysszy I want to ask whether this method of traversing AST nodes into a sequence comes from a particular paper. Is there a paper that formally introduces this method? My current work uses this AST traversal method, and I want to cite a reference for it.

zysszy commented 1 year ago

We directly use the traditional preorder traversal algorithm. Maybe you can refer to Wikipedia (https://en.wikipedia.org/wiki/Tree_traversal).
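For completeness, here is a minimal sketch (the class and function names are illustrative, not from the repo) of a preorder traversal that also emits a closing marker after each node's children, which reproduces sequences like the ones discussed above:

# Plain preorder traversal that emits "^" after each node's subtree.
class Node:
    def __init__(self, label, children=None):
        self.label = label
        self.children = children or []

def preorder_with_close(node, out):
    out.append(node.label)              # visit the node first (preorder)
    for child in node.children:
        preorder_with_close(child, out)
    out.append("^")                     # mark the end of this node's subtree
    return out

tree = Node("variable", [Node("$0")])
print(" ".join(preorder_with_close(tree, [])))  # -> "variable $0 ^ ^"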

zhaodezhu111 commented 1 year ago

OK, thank you for your reply!