junxiaosong / AlphaZero_Gomoku

An implementation of the AlphaZero algorithm for Gomoku (also called Gobang or Five in a Row)
MIT License
3.27k stars 964 forks source link

Tensorflow版本的疑问 #28

Open dshnightmare opened 6 years ago

dshnightmare commented 6 years ago

tensorflow版本中input_state的维度转换直接用reshape应该有问题吧?输入是[batch_size, c, h, w],tensorflow需要的则是[batch_size, h, w, c],直接reshape的话只改变了维度,数据并没有转置

junxiaosong commented 6 years ago

看了下确实有问题,应该还是得用data_format='channels_first',这个tensorFlow版本是其他同学贡献的,当时看了,但没有发现这个问题,感谢指出。不知道是否方便帮忙修一下,提个pull request呢?

dshnightmare commented 6 years ago

好的,没问题,等我把整个代码跑一跑

junxiaosong commented 6 years ago

已合并#29 ,多谢

xinrui-zhuang commented 6 years ago

你好,我在使用TensorFlow版本时,保存的模型是best_policy.model.meta,不知道如何调用这个图模型

junxiaosong commented 6 years ago

@xinrui-zhuang 调用的时候,模型名字直接用"best_policy.model"即可加载

xinrui-zhuang commented 6 years ago

我之前有尝试过这样做,但是总会出现如下的错误。找了半天,不知道如何解决 Traceback (most recent call last): File "human_play.py", line 64, in run policy_param = pickle.load(open(model_file, 'rb')) _pickle.UnpicklingError: invalid load key, ''.

dshnightmare commented 6 years ago

@xinrui-zhuang tensorflow不用pickle那段代码,可以跑

xinrui-zhuang commented 6 years ago

@dshnightmare 你的意思是直接把pickle那段注释掉,把后面换成best_policy = PolicyValueNetNumpy(width, height, model_file) ?但是这样的话还是报错,不是很懂,望详细指教,十分感谢

junxiaosong commented 6 years ago

@xinrui-zhuang 意思是把64-72行全部注释掉,那一段的目的是让大家可以只用numpy就能加载提供的模型并和它对战;然后取消60-61行的注释,这两行是加载自己训练的模型用的,不管你是用的Theano/Lasagne, PyTorch 还是 TensorFlow

xinrui-zhuang commented 6 years ago

感谢,问题解决了,加载模型时使用tf.train.import_meta_graph()就可以使用保存的.model.meta文件了。但是重新训练了一个12*12的棋盘模型,跑了1500个epochs,但是效果还不是很好,可以在哪方面改进一下?

junxiaosong commented 6 years ago

@xinrui-zhuang 感觉1500个epoch对于12*12的大棋盘来说太少了,另外issue #14 #13 中的一些讨论也可以参考

biasbb commented 6 years ago

你好,我想问一下在训练的时候batch的值代表的是当前已进行的对局次数吗?那我训练终止之后把生成的model文件作为参数读到train.py里面开始训练的话总的训练次数是不是就应该是本次训练终止时的次数加上上次结束时的次数呢?

junxiaosong commented 6 years ago

@biasbb 训练的时候一个batch就是采样一个batch_size的数据并更新模型,在现在代码里的默认参数下和已经进行的对局次数是一样的。后面那个问题的回答也是肯定的。

yuan9778 commented 6 years ago

@xinrui-zhuang 请问 “加载模型时使用tf.train.import_meta_graph()就可以使用保存的.model.meta文件了”是什么意思?谢谢!