zysszy / TreeGen

A Tree-Based Transformer Architecture for Code Generation. (AAAI'20)
MIT License

Are the programs that TreeGen generates guaranteed to be syntactically correct? #11

Open brando90 opened 3 years ago

brando90 commented 3 years ago

Hi Authors of TreeGen!

First thanks again for sharing your code and for the very interesting work!

I was reading your work and noticed that you predict the grammar rules all in one go with a single classification layer. This puzzled me, because my understanding (at least for context-free grammars, CFGs) is that a rewrite rule is sort of "Markovian": we rewrite the current symbol into its next possible symbols using one single rule, and crucially that rule is conditioned on the non-terminal being rewritten. However, that doesn't seem to be the case if we have a single classification layer over all rules, like:

prob_rule = softmax(h_decoder * W_r)
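To make the concern concrete, here is a minimal sketch of a single flat classification layer over all rules. It assumes a hypothetical four-rule toy grammar, made-up weights, and W_r stored as one weight vector per rule; none of these names come from the TreeGen code. Because the softmax runs over every rule, some probability mass can land on rules whose left-hand side is not the node currently being expanded.

```python
import math

# Hypothetical toy grammar: each rule is (left-hand side, right-hand side).
RULES = [
    ("Expr", ("Expr", "+", "Term")),
    ("Expr", ("Term",)),
    ("Term", ("NUM",)),
    ("Stmt", ("return", "Expr")),
]


def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rule_probs(h_decoder, W_r):
    """prob_rule = softmax(h_decoder * W_r); W_r holds one weight vector per rule."""
    logits = [sum(h * w for h, w in zip(h_decoder, col)) for col in W_r]
    return softmax(logits)


def invalid_mass(probs, current_nt):
    """Probability assigned to rules that cannot expand current_nt."""
    return sum(p for p, (lhs, _) in zip(probs, RULES) if lhs != current_nt)


# Made-up decoder state and weights, for illustration only.
h_decoder = [1.0, 0.5]
W_r = [[0.2, 0.1], [0.4, -0.3], [0.9, 0.8], [0.5, 0.5]]
probs = rule_probs(h_decoder, W_r)
leak = invalid_mass(probs, "Expr")  # > 0: mass on the Term and Stmt rules
```

Nothing in this layer knows that only the two `Expr` rules can apply at an `Expr` node; any such preference has to be learned from data.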

I think I've seen other papers use this same "put all the rules in the final layer" approach, but then I saw that a related paper you cite says (emphasis mine):

More recent work generates programs by predicting the grammar rule or rewriting rule to apply at each step (Xiong et al. 2018; Yin and Neubig 2017; Rabinovich, Stern, and Klein 2017); thus, the generated programs are guaranteed to be syntactically correct.

I was wondering, can you clarify how your work guarantees programs are always syntactically correct like that work claims they do? (Perhaps you don't and that other work does and that is fine but I just wanted to double check).

https://arxiv.org/abs/1811.06837

brando90 commented 3 years ago

Note: I don't mean that the current GRU in your work doesn't take the non-terminal as input (it does). But I would have thought that, if we are expanding non-terminal NT, that would index the allowed/valid weights from W_r, e.g., something like W_r,NT, so that we get:

prob_rule = softmax(h_decoder * W_r,NT)

I think this truly restricts the output to valid rules, assuming that W_r,NT corresponds only to the valid rules that expand non-terminal NT.
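A minimal sketch of the W_r,NT idea, using a hypothetical toy grammar and made-up weights: `rule_probs_for_nt` indexes only the weight vectors of rules that expand NT, and `masked_rule_probs` obtains the same distribution by setting the other logits to -inf before a full softmax. These helper names are illustrative assumptions, not functions from the TreeGen code.

```python
import math

# Hypothetical toy grammar: (left-hand side, right-hand side).
RULES = [
    ("Expr", ("Expr", "+", "Term")),
    ("Expr", ("Term",)),
    ("Term", ("NUM",)),
    ("Stmt", ("return", "Expr")),
]


def softmax(logits):
    m = max(logits)  # assumes at least one finite logit
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def rule_probs_for_nt(h_decoder, W_r, nt):
    """softmax(h_decoder * W_r,NT): index only the weight vectors of rules expanding nt."""
    allowed = [i for i, (lhs, _) in enumerate(RULES) if lhs == nt]
    probs = softmax([dot(h_decoder, W_r[i]) for i in allowed])
    return dict(zip(allowed, probs))  # rule index -> probability


def masked_rule_probs(h_decoder, W_r, nt):
    """Equivalent masking form: logits of rules not expanding nt become -inf."""
    logits = [
        dot(h_decoder, col) if lhs == nt else -math.inf
        for col, (lhs, _) in zip(W_r, RULES)
    ]
    return softmax(logits)  # exp(-inf) == 0.0, so invalid rules get probability 0


h_decoder = [1.0, 0.5]
W_r = [[0.2, 0.1], [0.4, -0.3], [0.9, 0.8], [0.5, 0.5]]
indexed = rule_probs_for_nt(h_decoder, W_r, "Expr")
masked = masked_rule_probs(h_decoder, W_r, "Expr")
```

The masking form keeps W_r as a single matrix, which is usually easier to batch than per-non-terminal slices, while giving the same probabilities for the valid rules.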

zysszy commented 3 years ago

Sorry for the late reply. I was on holiday for the last three days.

I was wondering, can you clarify how your work guarantees programs are always syntactically correct like that work claims they do? (Perhaps you don't and that other work does and that is fine but I just wanted to double check).

We use the post-processing code at Line 500

if i < len(Rule) and Rule[i][0] != JavaOut.Node:

and Line 506

if i >= len(Rule) and JavaOut.Node.strip() not in copy_node:

to guarantee the syntactic correctness of the output during inference. In these lines, we remove the invalid rules from the prediction rather than indexing the allowed weights from W_r (the two options perform very similarly).

Zeyu
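The two checks quoted above can be sketched as a filter over candidate indices. This is a minimal illustration, not TreeGen's actual code: it assumes that indices below `len(RULES)` are grammar rules whose left-hand side is `RULES[i][0]` (mirroring `Rule[i][0]`), and that higher indices are copy actions, allowed only when the current node is in a hypothetical `COPY_NODES` set standing in for `copy_node`.

```python
# Hypothetical toy grammar and copy set; stand-ins, not TreeGen's data structures.
RULES = [
    ("Expr", ("Expr", "+", "Term")),
    ("Expr", ("Term",)),
    ("Term", ("Name",)),
]
COPY_NODES = {"Name"}  # nodes whose value may be copied from the input (assumed)


def is_valid(i, node):
    """Candidate i is a grammar rule if i < len(RULES), otherwise a copy action."""
    if i < len(RULES):
        return RULES[i][0] == node  # the rule must expand the current node
    return node.strip() in COPY_NODES  # copying is only allowed at copyable nodes


def top_k_valid(scores, node, k):
    """Remove invalid candidates, then keep the k best (index, score) pairs."""
    valid = [(i, s) for i, s in enumerate(scores) if is_valid(i, node)]
    valid.sort(key=lambda pair: pair[1], reverse=True)
    return valid[:k]


# Scores for 3 rules followed by 2 copy positions (made-up numbers).
scores = [0.1, 0.2, 0.5, 0.15, 0.05]
expr_candidates = top_k_valid(scores, "Expr", 2)
```

Here the highest-scoring candidate (index 2, a `Term` rule) is dropped at an `Expr` node, so only the two `Expr` rules reach the beam.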

brando90 commented 3 years ago

Hi Zeyu!

No worries, I hope you enjoyed your vacation; we researchers sometimes don't rest enough, imho!

Ah, so it seems you do somehow use only the weights that lead to valid ASTs. However, during training it seems that you train all the rules together at once. Do you not also "remove" (mask, index, or something like that) the weights that are invalid?

This is the scenario I am worried about:

Do you see what I am worried about? I would have expected that one of the advantages of learning grammar rules compared to a language model that only sees tokens is that by using the grammar the model doesn't have to waste time (implicitly) learning the grammar. However, if we train all the rules at once without removing the invalid ones, then everything is sort of flattened like in a language model, and I'd assume we are partially limiting our model (by not leveraging the inductive bias of the grammar).

Does this make sense to you? At the core of my question is whether your model also "removes stuff" during training (since that conditions it to use the grammar correctly).
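A minimal sketch of the difference being asked about, assuming a hypothetical toy grammar and made-up logits: `rule_loss` computes the cross-entropy of the gold rule either over all rules (no masking) or, with `mask=True`, only over the rules that expand the current non-terminal. The names are illustrative, not from the TreeGen code.

```python
import math

# Hypothetical toy grammar: (left-hand side, right-hand side).
RULES = [
    ("Expr", ("Expr", "+", "Term")),
    ("Expr", ("Term",)),
    ("Term", ("NUM",)),
    ("Stmt", ("return", "Expr")),
]


def log_prob(logits, index, allowed=None):
    """log softmax(logits)[index], normalizing only over `allowed` (all indices if None)."""
    idx = range(len(logits)) if allowed is None else allowed
    m = max(logits[i] for i in idx)
    log_z = m + math.log(sum(math.exp(logits[i] - m) for i in idx))
    return logits[index] - log_z


def rule_loss(logits, gold, nt, mask=False):
    """Cross-entropy of the gold rule; with mask=True only rules expanding nt compete."""
    allowed = None
    if mask:
        allowed = [i for i, (lhs, _) in enumerate(RULES) if lhs == nt]
        assert gold in allowed, "gold rule must expand the current non-terminal"
    return -log_prob(logits, gold, allowed)


logits = [0.25, 0.25, 1.3, 0.75]  # made-up scores for the four rules
unmasked = rule_loss(logits, gold=0, nt="Expr")
masked = rule_loss(logits, gold=0, nt="Expr", mask=True)
```

Because the masked softmax normalizes over a subset of the rules, the masked loss is never larger than the unmasked one. In the unmasked case, part of the gradient goes into pushing down the logits of invalid rules, which is the implicit grammar learning described above.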

zysszy commented 3 years ago

Ah, so it seems you do somehow use only the weights that lead to valid ASTs. However, during training it seems that you train all the rules together at once. Do you not also "remove" (mask, index, or something like that) the weights that are invalid?

No, we don't remove the weights that are invalid during training.

I would have expected that one of the advantages of learning grammar rules compared to a language model that only sees tokens is that by using the grammar the model doesn't have to waste time (implicitly) learning the grammar

The programs in our training data are all made up of valid rules. Though we don't remove the invalid weights during training, we found that in most cases our model learns to use only valid rules. It is not just a language model but a grammar-aware model.

Zeyu

brando90 commented 3 years ago

Hi Zeyu,

Thanks for your response!

When I said remove I really meant "mask" but I think you understood me. OK, that's fine if that's what you did. It's good that it learns to produce syntactically valid programs, but if it treats tokens from the grammar the same way as tokens from the literal program (as NLP models do), then what is the advantage of using a grammar-based model at all? As you mentioned, it sometimes produces invalid programs, which I believe are removed during beam search. So the model isn't trained to produce valid programs via the grammar, since the grammar isn't actually restricting the output of the model in any way. Am I right? What is the advantage of learning the grammar if the sequence is just as arbitrary in structure as feeding the sequence of literals as tokens?

Thanks for your kind responses!

zysszy commented 3 years ago

When I said remove I really meant "mask" but I think you understood me.

I understood it the same way.

What is the advantage of using a grammar-based model at all?

Using grammar rules aims to make training easier and to ensure that every step during inference (beam search) is syntactically correct. If you treat programs as plain text, it is hard to ensure that every decoding step is syntactically correct.

Zeyu
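To illustrate why every decoding step stays syntactically correct, here is a minimal greedy sketch of grammar-constrained decoding (beam search keeps several partial derivations instead of one, but restricts each expansion the same way). It assumes a hypothetical toy grammar and a hypothetical `score_fn(history, node, rule_index)` standing in for the model; the derivation is leftmost, so the node to expand next is always known and only rules with that left-hand side are scored.

```python
# Hypothetical toy grammar: (left-hand side, right-hand side).
RULES = [
    ("Expr", ("Expr", "+", "Term")),
    ("Expr", ("Term",)),
    ("Term", ("NUM",)),
]
NONTERMINALS = {lhs for lhs, _ in RULES}


def decode(score_fn, start="Expr", max_steps=20):
    """Greedy leftmost derivation; only rules expanding the frontier node compete."""
    stack, applied, tokens = [start], [], []
    while stack and len(applied) < max_steps:
        node = stack.pop()
        if node not in NONTERMINALS:
            tokens.append(node)  # terminal: emit it as output
            continue
        allowed = [i for i, (lhs, _) in enumerate(RULES) if lhs == node]
        best = max(allowed, key=lambda i: score_fn(applied, node, i))
        applied.append(best)
        # Push the right-hand side in reverse so the leftmost child is expanded next.
        stack.extend(reversed(RULES[best][1]))
    return applied, tokens


def prefer_recursion_once(history, node, i):
    """Made-up scorer: use the recursive Expr rule first, then the simple one."""
    if i == 0 and not history:
        return 1.0
    return 0.5 if i == 1 else 0.1


applied, tokens = decode(prefer_recursion_once)
```

Even a scorer that gives its highest score to an inapplicable rule cannot produce an ill-formed tree here, because invalid rules are never candidates at a given node.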