Open yueml opened 3 years ago
您好,感谢您的关注~
这里train_trans.txt 中,数据9元组,元组的每一行代表的是: 1)输入的自然语言描述 2)代码的AST (不过用不到) 3)代码AST中每个结点对应的父结点 (用不到) 4)代码AST中每个结点对应的祖父结点 (用不到) 5)用不到 6)已经使用过的规则序列(从1开始标号) 7)用于decoder,需要inference的规则序列(从0开始标号)。当规则号 r 大于等于 10000 时,表示 copy 自然语言描述中 第 r - 10000 个词。 8)用不到 9)程序的抽象语法树的邻接矩阵表示。其中,-1代表无父结点,第 i 个位置的 k 代表第 i 个位置的规则的父结点是第 k 个位置的规则。
train_tree.txt 中,每一行代表预测时候所需要的tree-path信息(query),它会对应着 train_trans.txt 中的每一个规则(对于整个卡牌代码来说是多行对应一行,如果卡牌inference需要 m 步,则是 m 行对应一行)。
Zeyu
您好,非常感谢您耐心细致的解答!我为feed_dict输入的batch对应上train_trans.txt,做了个表格,不知道正不正确。
但我仍有一些困惑 1.下面python代码中feed_dict某些输入搞不明白含义(下面代码注释中标记了问号[?]),希望得到您的解惑。 2.在inference过程中,当AST展开当前节点时,需要符合文法规则,请问这在代码中是怎么实现的呢,或者ta的实现位置在哪。
再次感谢您的解答,祝生活安康,前路似锦!
_, pre, a = sess.run([model.optim, model.correct_prediction, model.cross_entropy], feed_dict={
model.input_NL: batch[0], //输入的自然语言
model.input_NLChar:batch[1], //输入的自然语言字符
model.inputparentlist: batch[5], //?
model.inputrulelist:batch[6], //这里是对应Rule.txt中的信息么?
model.inputrulelistnode:batch[7], //?
model.inputrulelistson:batch[8], //?
model.inputY_Num: batch[9], //要预测出的规则+变量标识符(pointer network)
model.tree_path_vec: batch[12], //tree-path信息,作为decoder的query
model.labels:batch[18], //?
model.loss_mask:loss_mask,
model.antimask: pre_mask(),
model.treemask: batch[16], //?
model.father_mat:batch[17], //?
model.state:state, //?
model.keep_prob: 0.85,
model.rewards: rewards, //?
model.is_train: True
})
您好,
这个代码经过了太多次修改,有些没用的代码没有删除,以及有些变量名不大合适,给您带来了麻烦,非常抱歉。
对于第一个问题和表格:
既然表格和代码注释近似同源,我就把信息补充在注释上了
_, pre, a = sess.run([model.optim, model.correct_prediction, model.cross_entropy], feed_dict={
model.input_NL: batch[0], //输入的自然语言
model.input_NLChar:batch[1], //输入的自然语言字符
model.inputparentlist: batch[5], // 用不到
model.inputrulelist:batch[6], // 已经使用过的规则序列(从1开始标号),对应 9元组 中第 6 行
model.inputrulelistnode:batch[7], // Rule.txt中的信息,每个位置中用于 rule definition encoding 的父结点 (通过规则序号查表可得)
model.inputrulelistson:batch[8], // Rule.txt中的信息,每个位置中用于 rule definition encoding 的子结点(最多10个)(通过规则序号查表可得)
model.inputY_Num: batch[9], //要预测出的规则+变量标识符(pointer network)
model.tree_path_vec: batch[12], //tree-path信息,作为decoder的query
model.labels:batch[18], // 用于 模型中的 depth embedding,表示结点在AST中的深度信息
model.loss_mask:loss_mask,
model.antimask: pre_mask(),
model.treemask: batch[16], // Tree Conv 部分所使用的邻接矩阵
model.father_mat:batch[17], // 用不到
model.state:state, // 用不到
model.keep_prob: 0.85,
model.rewards: rewards, // 用不到
model.is_train: True
})
对于第二个问题:
我们在inference过程中检测了一下:如果预测的规则的父结点和我们要扩展的结点不是同一类型结点,我们就将其过滤掉。
其具体代码位于:predict_HS-B.py
的 582 行:
if i < len(Rule) and Rule[i][0] != JavaOut.Node:
这行表示,如果预测的规则号i
小于预定义规则的总数量(如果大于就会执行指针网络的copy规则) 并且 规则i
的父结点和我们要扩展的结点是同一类型的结点。
再次感谢您的关注~如果有什么问题欢迎继续问我~
Zeyu
您好,我对模型的自回归模式有些疑惑(╯﹏╰)b。
模型的输出 y_result 是整个样本的规则序列,但是在同一个训练样本里,不同的节点(当前预测的rule)使用的tree_path、inputrulelistnode、inputrulelistson和深度信息 labels 是不同的,这是怎么做到的呢?
抱歉,最近太忙了,回复晚了。
其实输出,在训练和 inference的时候是个矩阵 [输出长度, 输出预测结果],而不是一个自回归的 vector,因而可以用矩阵存下每步的信息。
Zeyu您好,@zysszy
我想要请教一个关于第7行数据的问题
7)用于decoder,需要inference的规则序列(从0开始标号)。当规则号 r 大于等于 10000 时,表示 copy 自然语言描述中 第 r - 10000 个词。
在train_trans.txt
中,自然语言是Acidic Swamp Ooze NAME_END 3 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END NEUTRAL PLAYER_CLS_END NIL RACE_END COMMON RARITY_END < b > battlecry : < /b > destroy your opponent 's weapon .
,第7行是0 1 2 3 10003 4 5 10012 6 7 8 9 5 10 11 12 13 14 15 16 13 14 17 5 18 19 20 21 22 23 10003 10008 16 24 5 25 26 16 24 5 27 10018 28 29 10023 13 14 17 5 30 31 14 17 5 32 19 20 14 17 5 33 34 14 17 5 35 19 20 20 20 7 36 37 5 10 5 38 39 40 13 14 17 5 10012 41 10004 10006 20
,其对应的代码是
class AcidicSwampOoze(MinionCard):
def __init__(self):
super().__init__("Acidic Swamp Ooze", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Destroy(), WeaponSelector(EnemyPlayer())))
def create_minion(self, player):
return Minion(3, 2)
那么第7行的10003表示copy自然语言中的第3个时,应该是Ooze
,而不是AcidicSwampOoze
,请问这里是还有别的特殊规则吗?
另外是第一个10012处,应该是其父类MinionCard
,但自然语言中第12个单词是DUR_END
,而不是Minion
。而且Minion
又是如何变为MinionCard
的呢?
非常感谢!
Zeyu您好,@zysszy
我想要请教一个关于第7行数据的问题
7)用于decoder,需要inference的规则序列(从0开始标号)。当规则号 r 大于等于 10000 时,表示 copy 自然语言描述中 第 r - 10000 个词。
在
train_trans.txt
中,自然语言是Acidic Swamp Ooze NAME_END 3 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END NEUTRAL PLAYER_CLS_END NIL RACE_END COMMON RARITY_END < b > battlecry : < /b > destroy your opponent 's weapon .
,第7行是0 1 2 3 10003 4 5 10012 6 7 8 9 5 10 11 12 13 14 15 16 13 14 17 5 18 19 20 21 22 23 10003 10008 16 24 5 25 26 16 24 5 27 10018 28 29 10023 13 14 17 5 30 31 14 17 5 32 19 20 14 17 5 33 34 14 17 5 35 19 20 20 20 7 36 37 5 10 5 38 39 40 13 14 17 5 10012 41 10004 10006 20
,其对应的代码是class AcidicSwampOoze(MinionCard): def __init__(self): super().__init__("Acidic Swamp Ooze", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Destroy(), WeaponSelector(EnemyPlayer()))) def create_minion(self, player): return Minion(3, 2)
那么第7行的10003表示copy自然语言中的第3个时,应该是
Ooze
,而不是AcidicSwampOoze
,请问这里是还有别的特殊规则吗?另外是第一个10012处,应该是其父类
MinionCard
,但自然语言中第12个单词是DUR_END
,而不是Minion
。而且Minion
又是如何变为MinionCard
的呢?非常感谢!
这里是HS-B的环境下,我使用了一些semi structural的information。这里我在处理的时候设置了一个特殊的规则(在HS-B里允许copy一整个字段)。10003 代表从0开始数第3个词,应该是 NAME_END,代表copy一整个name。
MinionCard也是copy了一整个字段,同时在HS-B里,我加了一段后处理,使得MinionCard == Minion。
Zeyu
谢谢您的回复!