Zeta36 / chess-alpha-zero

Chess reinforcement learning by AlphaGo Zero methods.
MIT License
2.14k stars 481 forks source link

Error in notebook #76

Open · reikdas opened this issue 6 years ago

reikdas commented 6 years ago

Running cell 6 in `notebooks/demo.ipynb` gives the following error:

Using TensorFlow backend.
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1625   try:
-> 1626     c_op = c_api.TF_FinishOperation(op_desc)
   1627   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'input_batchnorm/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,256,1,1], [].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-6-9b6eecc00968> in <module>
      1 if not me_player:
----> 2     me_player = get_player(default_config)
      3 action = me_player.action(env, False)
      4 print(f"bestmove {action}")

<ipython-input-3-a13dfe07b64a> in get_player(config)
      3     from chess_zero.lib.model_helper import load_best_model_weight
      4     model = ChessModel(config)
----> 5     if not load_best_model_weight(model):
      6         raise RuntimeError("Best model not found!")
      7     return ChessPlayer(config, model.get_pipes(config.play.search_threads))

~/dev/chess-alpha-zero/src/chess_zero/lib/model_helper.py in load_best_model_weight(model)
     13     :return:
     14     """
---> 15     return model.load(model.config.resource.model_best_config_path, model.config.resource.model_best_weight_path)
     16 
     17 

~/dev/chess-alpha-zero/src/chess_zero/agent/model_chess.py in load(self, config_path, weight_path)
    143             logger.debug(f"loading model from {config_path}")
    144             with open(config_path, "rt") as f:
--> 145                 self.model = Model.from_config(json.load(f))
    146             self.model.load_weights(weight_path)
    147             self.model._make_predict_function()

~/anaconda3/lib/python3.6/site-packages/keras/engine/network.py in from_config(cls, config, custom_objects)
   1030                 if layer in unprocessed_nodes:
   1031                     for node_data in unprocessed_nodes.pop(layer):
-> 1032                         process_node(layer, node_data)
   1033 
   1034         name = config.get('name')

~/anaconda3/lib/python3.6/site-packages/keras/engine/network.py in process_node(layer, node_data)
    989             # and building the layer if needed.
    990             if input_tensors:
--> 991                 layer(unpack_singleton(input_tensors), **kwargs)
    992 
    993         def process_layer(layer_data):

~/anaconda3/lib/python3.6/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
    455             # Actually call the layer,
    456             # collecting output(s), mask(s), and shape(s).
--> 457             output = self.call(inputs, **kwargs)
    458             output_mask = self.compute_mask(inputs, previous_mask)
    459 

~/anaconda3/lib/python3.6/site-packages/keras/layers/normalization.py in call(self, inputs, training)
    204         return K.in_train_phase(normed_training,
    205                                 normalize_inference,
--> 206                                 training=training)
    207 
    208     def get_config(self):

~/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in in_train_phase(x, alt, training)
   3121 
   3122     # else: assume learning phase is a placeholder tensor.
-> 3123     x = switch(training, x, alt)
   3124     if uses_learning_phase:
   3125         x._uses_learning_phase = True

~/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in switch(condition, then_expression, else_expression)
   3056         x = tf.cond(condition,
   3057                     then_expression_fn,
-> 3058                     else_expression_fn)
   3059     else:
   3060         # tf.where needs its condition tensor

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    486                 'in a future version' if date is None else ('after %s' % date),
    487                 instructions)
--> 488       return func(*args, **kwargs)
    489     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    490                                        _add_deprecated_arg_notice_to_docstring(

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   2085     try:
   2086       context_f.Enter()
-> 2087       orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
   2088       if orig_res_f is None:
   2089         raise ValueError("false_fn must have a return value.")

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildCondBranch(self, fn)
   1918     """Add the subgraph defined by fn() to the graph."""
   1919     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-> 1920     original_result = fn()
   1921     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1922     if len(post_summaries) > len(pre_summaries):

~/anaconda3/lib/python3.6/site-packages/keras/layers/normalization.py in normalize_inference()
    165                     broadcast_gamma,
    166                     axis=self.axis,
--> 167                     epsilon=self.epsilon)
    168             else:
    169                 return K.batch_normalization(

~/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
   1906             # so it may have extra axes with 1, it is not needed and should be removed
   1907             if ndim(mean) > 1:
-> 1908                 mean = tf.reshape(mean, (-1))
   1909             if ndim(var) > 1:
   1910                 var = tf.reshape(var, (-1))

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
   6294   if _ctx is None or not _ctx._eager_context.is_eager:
   6295     _, _, _op = _op_def_lib._apply_op_helper(
-> 6296         "Reshape", tensor=tensor, shape=shape, name=name)
   6297     _result = _op.outputs[:]
   6298     _inputs_flat = _op.inputs

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    785         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    786                          input_types=input_types, attrs=attr_protos,
--> 787                          op_def=op_def)
    788       return output_structure, op_def.is_stateful, op
    789 

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    486                 'in a future version' if date is None else ('after %s' % date),
    487                 instructions)
--> 488       return func(*args, **kwargs)
    489     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    490                                        _add_deprecated_arg_notice_to_docstring(

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3270           input_types=input_types,
   3271           original_op=self._default_original_op,
-> 3272           op_def=op_def)
   3273       self._create_op_helper(ret, compute_device=compute_device)
   3274     return ret

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1788           op_def, inputs, node_def.attr)
   1789       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1790                                 control_input_ops)
   1791 
   1792     # Initialize self._outputs.

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1627   except errors.InvalidArgumentError as e:
   1628     # Convert to ValueError for backwards compatibility.
-> 1629     raise ValueError(str(e))
   1630 
   1631   return c_op

ValueError: Shape must be rank 1 but is rank 0 for 'input_batchnorm/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,256,1,1], [].