zysszy / TreeGen

A Tree-Based Transformer Architecture for Code Generation. (AAAI'20)
MIT License
90 stars 27 forks source link

train_trans.txt和train_tree.txt里面数据含义和作用是什么 #8

Open yueml opened 3 years ago

yueml commented 3 years ago
    您好,非常感谢你分享的代码,我对输入数据有些疑问,以HS-B数据为例,train_trans.txt一个卡牌数据是有9行,第一行是自然语言描述,第二行是卡牌代码的ast树,我这样理解对么。其它行代表的是什么呢?它们的作用又分别是什么?
    另外,train_tree.txt每一行代表的是什么,它貌似和train_trans.txt有着联系。
    期待您的解答,不胜感激!
zysszy commented 3 years ago

您好,感谢您的关注~

这里train_trans.txt 中,数据9元组,元组的每一行代表的是: 1)输入的自然语言描述 2)代码的AST (不过用不到) 3)代码AST中每个结点对应的父结点 (用不到) 4)代码AST中每个结点对应的祖父结点 (用不到) 5)用不到 6)已经使用过的规则序列(从1开始标号) 7)用于decoder,需要inference的规则序列(从0开始标号)。当规则号 r 大于等于 10000 时,表示 copy 自然语言描述中 第 r - 10000 个词。 8)用不到 9)程序的抽象语法树的邻接矩阵表示。其中,-1代表无父结点,第 i 个位置的 k 代表第 i 个位置的规则的父结点是第 k 个位置的规则。

train_tree.txt 中,每一行代表预测时候所需要的tree-path信息(query),它会对应着 train_trans.txt 中的每一个规则(对于整个卡牌代码来说是多行对应一行,如果卡牌inference需要 m 步,则是 m 行对应一行)。

Zeyu

yueml commented 3 years ago

您好,非常感谢您耐心细致的解答!我为feed_dict输入的batch对应上train_trans.txt,做了个表格,不知道正不正确。

无标题

但我仍有一些困惑 1.下面python代码中feed_dict某些输入搞不明白含义(下面代码注释中标记了问号[?]),希望得到您的解惑。 2.在inference过程中,当AST展开当前节点时,需要符合文法规则,请问这在代码中是怎么实现的呢,或者ta的实现位置在哪。

再次感谢您的解答,祝生活安康,前路似锦!

_, pre, a = sess.run([model.optim, model.correct_prediction, model.cross_entropy], feed_dict={
                             model.input_NL: batch[0],    //输入的自然语言
                             model.input_NLChar:batch[1],    //输入的自然语言字符
                             model.inputparentlist: batch[5],    //?
                             model.inputrulelist:batch[6],    //这里是对应Rule.txt中的信息么?
                             model.inputrulelistnode:batch[7],    //?
                             model.inputrulelistson:batch[8],    //?
                             model.inputY_Num: batch[9],    //要预测出的规则+变量标识符(pointer network)
                             model.tree_path_vec: batch[12],    //tree-path信息,作为decoder的query
                             model.labels:batch[18],    //?
                             model.loss_mask:loss_mask,
                             model.antimask: pre_mask(),
                             model.treemask: batch[16],    //?
                             model.father_mat:batch[17],    //?
                             model.state:state,    //?
                             model.keep_prob: 0.85,
                             model.rewards: rewards,    //?
                             model.is_train: True
                                                  })
zysszy commented 3 years ago

您好,

这个代码经过了太多次修改,有些没用的代码没有删除,以及有些变量名不大合适,给您带来了麻烦,非常抱歉。

对于第一个问题和表格:

既然表格和代码注释近似同源,我就把信息补充在注释上了

_, pre, a = sess.run([model.optim, model.correct_prediction, model.cross_entropy], feed_dict={
                             model.input_NL: batch[0],    //输入的自然语言
                             model.input_NLChar:batch[1],    //输入的自然语言字符
                             model.inputparentlist: batch[5],    // 用不到
                             model.inputrulelist:batch[6],    // 已经使用过的规则序列(从1开始标号),对应 9元组 中第 6 行
                             model.inputrulelistnode:batch[7],    // Rule.txt中的信息,每个位置中用于 rule definition encoding 的父结点 (通过规则序号查表可得)
                             model.inputrulelistson:batch[8],    // Rule.txt中的信息,每个位置中用于 rule definition encoding 的子结点(最多10个)(通过规则序号查表可得)
                             model.inputY_Num: batch[9],    //要预测出的规则+变量标识符(pointer network)
                             model.tree_path_vec: batch[12],    //tree-path信息,作为decoder的query
                             model.labels:batch[18],    // 用于 模型中的 depth embedding,表示结点在AST中的深度信息
                             model.loss_mask:loss_mask,
                             model.antimask: pre_mask(),
                             model.treemask: batch[16],    // Tree Conv 部分所使用的邻接矩阵
                             model.father_mat:batch[17],    // 用不到
                             model.state:state,    // 用不到
                             model.keep_prob: 0.85,
                             model.rewards: rewards,    // 用不到
                             model.is_train: True
                                                  })

对于第二个问题:

我们在inference过程中检测了一下:如果预测的规则的父结点和我们要扩展的结点不是同一类型结点,我们就将其过滤掉。 其具体代码位于:predict_HS-B.py 的 582 行:

if i < len(Rule) and Rule[i][0] != JavaOut.Node:

这行表示,如果预测的规则号i小于预定义规则的总数量(如果大于就会执行指针网络的copy规则) 并且 规则i的父结点和我们要扩展的结点是同一类型的结点。

再次感谢您的关注~如果有什么问题欢迎继续问我~

Zeyu

yueml commented 3 years ago

您好,我对模型的自回归模式有些疑惑(╯﹏╰)b。

模型的输出 y_result 是整个样本的规则序列,但是在同一个训练样本里,不同的节点(当前预测的rule)使用的tree_path、inputrulelistnode、inputrulelistson和深度信息 labels 是不同的,这是怎么做到的呢?

zysszy commented 3 years ago

抱歉,最近太忙了,回复晚了。

其实输出,在训练和 inference的时候是个矩阵 [输出长度, 输出预测结果],而不是一个自回归的 vector,因而可以用矩阵存下每步的信息。

boyang9602 commented 3 years ago

Zeyu您好,@zysszy

我想要请教一个关于第7行数据的问题

7)用于decoder,需要inference的规则序列(从0开始标号)。当规则号 r 大于等于 10000 时,表示 copy 自然语言描述中 第 r - 10000 个词。

train_trans.txt中,自然语言是Acidic Swamp Ooze NAME_END 3 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END NEUTRAL PLAYER_CLS_END NIL RACE_END COMMON RARITY_END < b > battlecry : < /b > destroy your opponent 's weapon .,第7行是0 1 2 3 10003 4 5 10012 6 7 8 9 5 10 11 12 13 14 15 16 13 14 17 5 18 19 20 21 22 23 10003 10008 16 24 5 25 26 16 24 5 27 10018 28 29 10023 13 14 17 5 30 31 14 17 5 32 19 20 14 17 5 33 34 14 17 5 35 19 20 20 20 7 36 37 5 10 5 38 39 40 13 14 17 5 10012 41 10004 10006 20,其对应的代码是

class AcidicSwampOoze(MinionCard):
    def __init__(self):
        super().__init__("Acidic Swamp Ooze", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Destroy(), WeaponSelector(EnemyPlayer())))

    def create_minion(self, player):
        return Minion(3, 2)

那么第7行的10003表示copy自然语言中的第3个时,应该是Ooze,而不是AcidicSwampOoze,请问这里是还有别的特殊规则吗?

另外是第一个10012处,应该是其父类MinionCard,但自然语言中第12个单词是DUR_END,而不是Minion。而且Minion又是如何变为MinionCard的呢?

非常感谢!

zysszy commented 3 years ago

Zeyu您好,@zysszy

我想要请教一个关于第7行数据的问题

7)用于decoder,需要inference的规则序列(从0开始标号)。当规则号 r 大于等于 10000 时,表示 copy 自然语言描述中 第 r - 10000 个词。

train_trans.txt中,自然语言是Acidic Swamp Ooze NAME_END 3 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END NEUTRAL PLAYER_CLS_END NIL RACE_END COMMON RARITY_END < b > battlecry : < /b > destroy your opponent 's weapon .,第7行是0 1 2 3 10003 4 5 10012 6 7 8 9 5 10 11 12 13 14 15 16 13 14 17 5 18 19 20 21 22 23 10003 10008 16 24 5 25 26 16 24 5 27 10018 28 29 10023 13 14 17 5 30 31 14 17 5 32 19 20 14 17 5 33 34 14 17 5 35 19 20 20 20 7 36 37 5 10 5 38 39 40 13 14 17 5 10012 41 10004 10006 20,其对应的代码是

class AcidicSwampOoze(MinionCard):
    def __init__(self):
        super().__init__("Acidic Swamp Ooze", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Destroy(), WeaponSelector(EnemyPlayer())))

    def create_minion(self, player):
        return Minion(3, 2)

那么第7行的10003表示copy自然语言中的第3个时,应该是Ooze,而不是AcidicSwampOoze,请问这里是还有别的特殊规则吗?

另外是第一个10012处,应该是其父类MinionCard,但自然语言中第12个单词是DUR_END,而不是Minion。而且Minion又是如何变为MinionCard的呢?

非常感谢!

这里是HS-B的环境下,我使用了一些semi structural的information。这里我在处理的时候设置了一个特殊的规则(在HS-B里允许copy一整个字段)。10003 代表从0开始数第3个词,应该是 NAME_END,代表copy一整个name。

MinionCard也是copy了一整个字段,同时在HS-B里,我加了一段后处理,使得MinionCard == Minion。

Zeyu

boyang9602 commented 3 years ago

谢谢您的回复!