google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0
5.34k stars · 783 forks

Shape (5953500,) must have rank >= 2 for input to models.EncodeProcessDecode() #130

Closed · joshneudorf closed this issue 3 years ago

joshneudorf commented 3 years ago

Hello, I am excited to make use of this great library for graph neural networks. However, I am running into an issue: I get ValueError: Shape (5953500,) must have rank >= 2 when passing an input to models.EncodeProcessDecode. The input is a GraphsTuple converted from networkx graphs using utils_np.networkxs_to_graphs_tuple (no global values, all node features set to 1.0, and edge features taken from a 90 x 90 adjacency matrix).

For example:

r_graphs = utils_np.networkxs_to_graphs_tuple(rmatrix_nx)

where rmatrix_nx is a list of networkx graphs.
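Roughly, each graph in that list is built from a 90 x 90 matrix like this (a simplified sketch with random stand-in data; the function name matrix_to_networkx is just for illustration, and my real preprocessing differs, but the feature layout is the same):

import networkx as nx
import numpy as np
from graph_nets import utils_np

def matrix_to_networkx(adj_matrix):
    # Every node gets the feature 1.0, and every directed edge (including
    # self-loops, giving 90 * 90 = 8100 edges per graph) gets the
    # corresponding adjacency-matrix entry as its feature.
    graph = nx.OrderedMultiDiGraph()
    num_nodes = adj_matrix.shape[0]  # 90 in my case
    for i in range(num_nodes):
        graph.add_node(i, features=1.0)  # note: a scalar, not a length-1 vector
    for i in range(num_nodes):
        for j in range(num_nodes):
            graph.add_edge(i, j, features=adj_matrix[i, j])
    return graph

rmatrix_nx = [matrix_to_networkx(np.random.rand(90, 90)) for _ in range(4)]
r_graphs = utils_np.networkxs_to_graphs_tuple(rmatrix_nx)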

I am roughly following the tf2 sort example, doing something like this to set up the model and the update_step function:

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10

# Data / training parameters.
num_training_iterations = 3000
batch_size_tr = 32
batch_size_ge = 100

# Optimizer.
learning_rate = 1e-3
optimizer = snt.optimizers.Adam(learning_rate)

model = models.EncodeProcessDecode(edge_output_size=1, node_output_size=1)
last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []

# Training.
def update_step(inputs_tr, targets_tr):
    with tf.GradientTape() as tape:
        outputs_tr = model(inputs_tr, num_processing_steps_tr)  # <- raises: Shape (5953500,) must have rank >= 2
        # Loss.
        # Need to convert to numpy before using the tri_mse function.
        targets_tr_nx = utils_np.graphs_tuple_to_networkxs(targets_tr)
        targets_tr_matrix = nx.adjacency_matrix(targets_tr_nx, weight="features")
        outputs_tr_nx = utils_np.graphs_tuple_to_networkxs(outputs_tr)
        outputs_tr_matrix = nx.adjacency_matrix(outputs_tr_nx, weight="features")
        loss_tr = tri_mse(targets_tr_matrix, outputs_tr_matrix)

    gradients = tape.gradient(loss_tr, model.trainable_variables)
    optimizer.apply(gradients, model.trainable_variables)
    return outputs_tr, loss_tr

Then I run the training steps (x_train, y_train, x_val, and y_val are the training and validation inputs and targets):

## Run training steps.
# You can interrupt this cell's training loop at any time, and visualize the
# intermediate results by running the next cell (below). You can then resume
# training by simply executing this cell again.

# How much time between logging and printing the current results.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training loss), Lge (test/generalization loss), "
      "Ctr (training fraction nodes/edges labeled correctly), "
      "Str (training fraction examples solved correctly), "
      "Cge (test/generalization fraction nodes/edges labeled correctly), "
      "Sge (test/generalization fraction examples solved correctly)")

start_time = time.time()
last_log_time = start_time
for iteration in range(last_iteration, num_training_iterations):
  last_iteration = iteration

  outputs_tr, loss_tr = update_step(x_train, y_train)

  the_time = time.time()
  elapsed_since_last_log = the_time - last_log_time
  if elapsed_since_last_log > log_every_seconds:
    last_log_time = the_time
    outputs_ge = model(x_val, num_processing_steps_ge)
    loss_ge_all = tri_mse(y_val, outputs_ge)
    loss_ge = loss_ge_all[-1]

    # Replace the globals again to prevent exceptions.
    outputs_tr[-1] = outputs_tr[-1].replace(globals=None)
    targets_tr = targets_tr.replace(globals=None)

    elapsed = time.time() - start_time
    losses_tr.append(loss_tr.numpy())
    corrects_tr.append(correct_tr)
    solveds_tr.append(solved_tr)
    losses_ge.append(loss_ge.numpy())
    corrects_ge.append(correct_ge)
    solveds_ge.append(solved_ge)
    logged_iterations.append(iteration)
    print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, "
          "Str {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, loss_tr.numpy(), loss_ge.numpy(),
              correct_tr, solved_tr, correct_ge, solved_ge))

I get this error:

ValueError                                Traceback (most recent call last)
<ipython-input-22-470244fb2b52> in <module>
     21   last_iteration = iteration
     22 
---> 23   outputs_tr, loss_tr = update_step(x_train, y_train)
     24 
     25   the_time = time.time()

<ipython-input-21-ef3b4aac418d> in update_step(inputs_tr, targets_tr)
     98 def update_step(inputs_tr, targets_tr):
     99     with tf.GradientTape() as tape:
--> 100         outputs_tr = model(inputs_tr, num_processing_steps_tr)
    101         # Loss.
    102         # Need to convert to numpy before using tri_mse function

[... identical sonnet decorator/name-scope frames (sonnet/src/utils.py:89, sonnet/src/base.py:272) elided between the frames below ...]

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/graph_nets/demos_tf2/models.py in __call__(self, input_op, num_processing_steps)
    119   def __call__(self, input_op, num_processing_steps):
--> 120     latent = self._encoder(input_op)
    121     latent0 = latent
    122     output_ops = []

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/graph_nets/demos_tf2/models.py in __call__(self, inputs)
     55   def __call__(self, inputs):
---> 56     return self._network(inputs)

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/graph_nets/modules.py in _build(self, graph)
    365     return graph.replace(
--> 366         edges=self._edge_model(graph.edges),
    367         nodes=self._node_model(graph.nodes),
    368         globals=self._global_model(graph.globals))

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/graph_nets/_base.py in _build(self, *args, **kwargs)
    111   def _build(self, *args, **kwargs):
--> 112     return self._model(*args, **kwargs)

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/sonnet/src/sequential.py in __call__(self, inputs, *args, **kwargs)
     70       if i == 0:
     71         # Pass additional arguments to the first layer.
---> 72         outputs = mod(outputs, *args, **kwargs)

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/sonnet/src/nets/mlp.py in __call__(self, inputs, is_training)
    101     for i, layer in enumerate(self._layers):
--> 102       inputs = layer(inputs)

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/sonnet/src/linear.py in __call__(self, inputs)
     88   def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
---> 89     self._initialize(inputs)
     90 
     91     outputs = tf.matmul(inputs, self.w)

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/sonnet/src/linear.py in _initialize(self, inputs)
     65   def _initialize(self, inputs: tf.Tensor):
     66     """Constructs parameters used by this module."""
---> 67     utils.assert_minimum_rank(inputs, 2)
     68 
     69     input_size = inputs.shape[-1]

~/anaconda3/envs/graph_nets/lib/python3.8/site-packages/sonnet/src/utils.py in assert_minimum_rank(inputs, rank)
    148   actual_rank = len(shape)
    149   if actual_rank < rank:
--> 150     raise ValueError("Shape %r must have rank >= %d" % (shape, rank))

ValueError: Shape (5953500,) must have rank >= 2

When I print the input GraphsTuple:

GraphsTuple(nodes=array([1., 1., 1., ..., 1., 1., 1.]),
            edges=array([1.        , 0.26318767, 0.20551166, ..., 0.28894122, 0.25962281, 1.        ]),
            receivers=<tf.Tensor: shape=(5953500,), dtype=int32, numpy=array([    0,     1,     2, ..., 66147, 66148, 66149], dtype=int32)>,
            senders=<tf.Tensor: shape=(5953500,), dtype=int32, numpy=array([    0,     0,     0, ..., 66149, 66149, 66149], dtype=int32)>,
            globals=[],
            n_node=array([90, 90, 90, ..., 90, 90, 90], dtype=int32),                  # one entry of 90 per graph
            n_edge=array([8100, 8100, 8100, ..., 8100, 8100, 8100], dtype=int32))      # 90 * 90 = 8100 edges per graph

I hope this isn't too verbose; I just wanted to make sure I provided enough information to describe the problem. Thank you for any insight you may have.

alvarosg commented 3 years ago

Thank you for your message.

EncodeProcessDecode uses MLPs as the node, edge, and global update functions, so the input features are expected to be a vector for each node, edge, and global. From the GraphsTuple you printed, it seems your nodes and edges each have a single scalar feature instead of a vector of size one, so you need to add that extra dimension.

You can do so either directly when you build the graphs (by setting the feature for each node/edge to something like [1.] rather than just 1.), or on the GraphsTuple directly, with something like:

graph_tuple = graph_tuple.replace(nodes=graph_tuple.nodes[:, None],
                                  edges=graph_tuple.edges[:, None])
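For the first option, the fix at construction time is just to store length-1 vectors instead of scalars when adding nodes and edges (a sketch, reusing the names from the construction example in your post):

graph.add_node(i, features=[1.0])                  # instead of features=1.0
graph.add_edge(i, j, features=[adj_matrix[i, j]])  # instead of a scalar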

Either of those solutions should give you input shapes like:

graph_tuple.nodes.shape = [total_num_nodes, 1]
graph_tuple.edges.shape = [total_num_edges, 1]

Your output shapes will have that same trailing dimension, so you may want to squeeze it before building your loss by doing the opposite:

graph_tuple = graph_tuple.replace(nodes=tf.squeeze(graph_tuple.nodes, axis=1), ...)
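Spelled out for both fields, assuming you want to drop the trailing dimension from nodes and edges alike, that would be something like:

graph_tuple = graph_tuple.replace(
    nodes=tf.squeeze(graph_tuple.nodes, axis=1),
    edges=tf.squeeze(graph_tuple.edges, axis=1))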

Note that the total number of edges in your batch is very high (5,953,500 to be precise), so you should probably drastically reduce your batch size or you will run out of memory.
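One simple way to do that is to convert and feed only a slice of your networkx graphs per step, instead of the whole list at once, for example (a sketch reusing rmatrix_nx and batch_size_tr from your code):

for start in range(0, len(rmatrix_nx), batch_size_tr):
    batch_nx = rmatrix_nx[start:start + batch_size_tr]
    batch_graphs = utils_np.networkxs_to_graphs_tuple(batch_nx)
    # ... add the feature dimension, convert to tensors, then call update_step ...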

Also, I noticed that the globals in your GraphsTuple are set to []. I am not sure why, but you may want to use .replace(globals=None) there too, so you get proper error messages inside the model.

If you are not interested in using globals at all, you may want to replace the GraphNetwork with an InteractionNetwork inside EncodeProcessDecode, or set dummy global features using utils_tf.set_zero_global_features, so you can more easily keep the GraphNetwork if you do want global updates.
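For example, set_zero_global_features gives every graph in the batch a zero-valued global feature vector of the size you choose:

from graph_nets import utils_tf

graph_tuple = utils_tf.set_zero_global_features(graph_tuple, global_size=1)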

Hope this helps!

joshneudorf commented 3 years ago

Thank you very much for your detailed response. graph_tuple = graph_tuple.replace(nodes=graph_tuple.nodes[:, None], edges=graph_tuple.edges[:, None]) fixed the rank issue. I have also made sure to squeeze that dimension in my loss function, switched the globals to None, and, as you predicted, I am running out of memory, so I will try reducing the batch size.

Thank you again for your response!