Hi @sumeshthakr, thank you for your interest. Just as you said, the current script first creates a set of graphs and then lets each GNN layer choose one to operate on: [Graph_gen -> GNN]. To create graphs dynamically, we need a loop like [Graph_gen -> GNN -> Graph_gen -> GNN -> Graph_gen ...], and we need to backpropagate through Graph_gen during training. One option is to insert a TF version of Graph_gen (e.g. kNN) between the GNN layers: https://github.com/WeijingShi/Point-GNN/blob/48f3d79d5b101d3a4b8439ba74c92fcad4f7cab0/models/models.py#L119
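The [Graph_gen -> GNN -> Graph_gen -> GNN ...] loop can be sketched in plain Python to show the control flow. This is only an illustration under assumptions, not Point-GNN code: a brute-force radius graph stands in for Graph_gen, a hypothetical averaging layer (`toy_gnn_layer`) stands in for a GNN layer, and the graph is rebuilt from the *current* features before every layer, so its edges can change from one layer to the next. It does not handle gradients, which is exactly why a trainable version needs the graph generation written in TF ops.

```python
import math

def radius_graph(coords, r):
    """Brute-force Graph_gen: (neighbor, center) index pairs with distance <= r."""
    return [(p, c)
            for c in range(len(coords))
            for p in range(len(coords))
            if math.dist(coords[p], coords[c]) <= r]

def toy_gnn_layer(features, edges):
    """Hypothetical GNN layer: each center averages the features of its neighbors."""
    dim = len(features[0])
    sums = [[0.0] * dim for _ in features]
    counts = [0] * len(features)
    for p, c in edges:
        counts[c] += 1
        for k in range(dim):
            sums[c][k] += features[p][k]
    # every center is its own neighbor (distance 0), so counts[c] >= 1
    return [[s / counts[c] for s in sums[c]] for c in range(len(features))]

def dynamic_gnn(features, r, num_layers):
    """Graph_gen -> GNN -> Graph_gen -> GNN ...; returns final features and edge counts."""
    edge_counts = []
    for _ in range(num_layers):
        edges = radius_graph(features, r)          # Graph_gen on the current features
        features = toy_gnn_layer(features, edges)  # GNN
        edge_counts.append(len(edges))
    return features, edge_counts

feats, edge_counts = dynamic_gnn([[0.0], [1.0], [2.1]], r=1.2, num_layers=2)
print(edge_counts)  # [7, 9]: the second graph has more edges than the first
```

Because the features move closer together after the first layer, the second Graph_gen call connects points 0 and 2, which a graph built once up front would never do.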
Thanks for the reply.
By "inserting a TF version of Graph_gen", do you mean adding the graph generation configurations between the GNN layer configs?
Hi @sumeshthakr, I mean we want to add a graph generation function between the GNN layers. Unfortunately, the current graph generation function is not written in TF 1.0, so adding it directly won't work. You might need to write the graph generation you want as a TF function.
Thanks, Weijing.
Can you point me to some resources describing graph generation methods in TF? Sorry to ask, but any help would be really appreciated. Thanks!
Hi @sumeshthakr, the graph generation method is a design choice. For example, we may want to use the cosine distance between the previous GNN layer's point features to compute a new radius-neighbor graph, or use the point coordinates plus a predicted offset to compute a new graph. We just want the method written in TF1's API so that the gradients are taken care of. I have two radius-neighbor graph methods in TF 1.0:
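The two design choices mentioned above can be sketched in plain Python: a radius graph over the cosine distance between point features, and a radius graph over coordinates shifted by a predicted per-point offset. The function names, the `eps` guard against zero-length feature vectors, and the (neighbor, center) pair layout (chosen to match what `tf_rnn` returns) are assumptions for illustration; a trainable version would express the same steps with TF1 ops.

```python
import math

def radius_pairs(items, dist_fn, r):
    """(neighbor, center) index pairs whose distance under dist_fn is <= r."""
    return [(p, c)
            for c in range(len(items))
            for p in range(len(items))
            if dist_fn(items[p], items[c]) <= r]

def cosine_distance(a, b, eps=1e-12):
    """1 - cosine similarity; eps avoids dividing by zero for zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / max(norm, eps)

def feature_cosine_graph(features, r):
    """Option 1: radius graph in feature space using cosine distance."""
    return radius_pairs(features, cosine_distance, r)

def offset_graph(xyz, offsets, r):
    """Option 2: radius graph on coordinates moved by a predicted per-point offset."""
    moved = [[x + d for x, d in zip(p, o)] for p, o in zip(xyz, offsets)]
    return radius_pairs(moved, math.dist, r)

print(feature_cosine_graph([[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]], r=0.1))
# [(0, 0), (1, 0), (0, 1), (1, 1), (2, 2)]
```

Cosine distance ignores feature magnitude, so `[1, 0]` and `[2, 0]` are neighbors while the orthogonal `[0, 1]` is not; the offset variant instead lets a layer pull points toward each other before the next graph is built.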
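```python
import tensorflow as tf

def tf_brute_rnn(points_xyz, centers_xyz, r):
    """Brute-force radius neighbors: returns [K, 2] int64 (point_index, center_index) pairs."""
    # squared norms: |p|^2 as [N, 1] and |c|^2 as [M]
    points_xyz_norm = tf.linalg.norm(points_xyz, axis=-1) ** 2
    points_xyz_norm_expand = tf.expand_dims(points_xyz_norm, axis=1)
    centers_xyz_norm = tf.linalg.norm(centers_xyz, axis=-1) ** 2
    # [N, M] squared distances: |p - c|^2 = |p|^2 - 2 p.c + |c|^2
    distance = centers_xyz_norm - 2 * tf.matmul(points_xyz, tf.transpose(centers_xyz)) + points_xyz_norm_expand
    neighbors = tf.where(distance <= r ** 2)
    return neighbors
```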
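`tf_brute_rnn` turns the all-pairs distance into one matrix multiply plus two broadcasts by expanding |p - c|² = |p|² - 2 p·c + |c|². A plain-Python reference of the same computation can help check the TF version on small inputs. It scans the [N, M] matrix in row-major order, which is the order in which `tf.where` lists the true entries. The name `brute_rnn_reference` is hypothetical.

```python
import math

def brute_rnn_reference(points, centers, r):
    """(point_index, center_index) pairs with |p - c| <= r, via |p|^2 - 2 p.c + |c|^2."""
    p_sq = [sum(x * x for x in p) for p in points]
    c_sq = [sum(x * x for x in c) for c in centers]
    pairs = []
    for i, p in enumerate(points):        # rows: points
        for j, c in enumerate(centers):   # columns: centers
            dot = sum(a * b for a, b in zip(p, c))
            if p_sq[i] - 2 * dot + c_sq[j] <= r ** 2:
                pairs.append((i, j))
    return pairs

points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
centers = [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
pairs = brute_rnn_reference(points, centers, r=1.5)
# the expanded form agrees with checking the distance directly
assert pairs == [(i, j) for i, p in enumerate(points)
                 for j, c in enumerate(centers) if math.dist(p, c) <= 1.5]
print(pairs)  # [(0, 0), (1, 0), (2, 1)]
```

The expansion trades accuracy for speed: when |p - c| is very close to r, floating-point cancellation can put a pair on the other side of the threshold than a direct distance would, so compare against a reference only away from that boundary.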
```python
import tensorflow as tf

def tf_rnn(points_xyz, centers_xyz, radius):
    """TensorFlow radius neighbor search on a grid of cells with side `radius`.

    Returns a [K, 2] int32 tensor of (point_index, center_index) pairs.
    """
    # find grid assignment, note the grid is padded with one cell on each edge.
    points_origin = tf.reduce_min(points_xyz, axis=0, keepdims=True)
    centers_origin = tf.reduce_min(centers_xyz, axis=0, keepdims=True)
    origin = tf.minimum(points_origin, centers_origin)
    points_cell = tf.cast((points_xyz - origin) // radius, tf.int32) + 1
    centers_cell = tf.cast((centers_xyz - origin) // radius, tf.int32) + 1
    max_point_cell = tf.reduce_max(points_cell, axis=0, keepdims=True)
    max_center_cell = tf.reduce_max(centers_cell, axis=0, keepdims=True)
    grid_size = tf.maximum(max_point_cell, max_center_cell) + 2
    # generate cell indices for points
    cell_index_base = tf.math.cumprod(grid_size, exclusive=True, axis=-1)
    points_cell_index = tf.reduce_sum(points_cell * cell_index_base, axis=-1)
    # sort points and find the (start, count) position of each cell
    # in the sorted order
    points_cell_order = tf.argsort(points_cell_index)
    points_cell_index = tf.gather(points_cell_index, points_cell_order)
    unique_cell_index, _, count = tf.unique_with_counts(points_cell_index)
    unique_cell = tf.math.floordiv(tf.expand_dims(unique_cell_index, axis=1),
                                   cell_index_base)
    unique_cell = tf.math.floormod(unique_cell, grid_size)
    unique_cell_start = tf.cumsum(count, exclusive=True)
    # put (start, count) on the grid
    grid_start = tf.scatter_nd(unique_cell, unique_cell_start, grid_size[0])
    grid_count = tf.scatter_nd(unique_cell, count, grid_size[0])
    # search neighbor cells: the 3x3x3 block around each center's cell
    neighbors_cell_offset = tf.constant([[
        [-1, -1, -1], [ 0, -1, -1], [ 1, -1, -1],
        [-1,  0, -1], [ 0,  0, -1], [ 1,  0, -1],
        [-1,  1, -1], [ 0,  1, -1], [ 1,  1, -1],
        [-1, -1,  0], [ 0, -1,  0], [ 1, -1,  0],
        [-1,  0,  0], [ 0,  0,  0], [ 1,  0,  0],
        [-1,  1,  0], [ 0,  1,  0], [ 1,  1,  0],
        [-1, -1,  1], [ 0, -1,  1], [ 1, -1,  1],
        [-1,  0,  1], [ 0,  0,  1], [ 1,  0,  1],
        [-1,  1,  1], [ 0,  1,  1], [ 1,  1,  1],
    ]], dtype=tf.int32)
    neighbors_cell = tf.expand_dims(centers_cell, 1) + neighbors_cell_offset
    neighbors_start = tf.gather_nd(grid_start, neighbors_cell)
    neighbors_count = tf.gather_nd(grid_count, neighbors_cell)
    centers_indices = tf.tile(
        tf.expand_dims(tf.range(tf.shape(centers_xyz)[0]), axis=1), (1, 27))
    neighbors_start = tf.reshape(neighbors_start, [-1])
    neighbors_count = tf.reshape(neighbors_count, [-1])
    centers_indices = tf.reshape(centers_indices, [-1])
    non_empty_cell = neighbors_count > 0
    neighbors_start = tf.boolean_mask(neighbors_start, non_empty_cell)
    neighbors_count = tf.boolean_mask(neighbors_count, non_empty_cell)
    centers_indices = tf.boolean_mask(centers_indices, non_empty_cell)
    # fetch neighbor points: expand each (start, count) run into consecutive
    # indices by scattering corrections at run starts, then taking a cumsum
    total_num_neighbors = tf.reduce_sum(neighbors_count)
    neighbors_pos = tf.expand_dims(tf.cumsum(neighbors_count, exclusive=True), axis=1)
    neighbors_index = tf.ones([total_num_neighbors], dtype=tf.int32)
    neighbors_index = neighbors_index + tf.scatter_nd(neighbors_pos, neighbors_start-1, [total_num_neighbors])
    neighbors_index = neighbors_index - tf.scatter_nd(neighbors_pos[1:], neighbors_start[:-1], [total_num_neighbors])
    neighbors_index = neighbors_index - tf.scatter_nd(neighbors_pos[1:], neighbors_count[:-1]-1, [total_num_neighbors])
    neighbors_index = tf.cumsum(neighbors_index)
    neighbors = tf.gather(points_cell_order, neighbors_index)
    # repeat each run's center index over the run with the same trick
    centers = tf.scatter_nd(neighbors_pos, centers_indices, [total_num_neighbors])
    centers = centers - tf.scatter_nd(neighbors_pos[1:], centers_indices[:-1], [total_num_neighbors])
    centers = tf.cumsum(centers)
    # filtering neighbors outside radius
    neighbor_xyz = tf.gather(points_xyz, neighbors)
    neighbor_center_xyz = tf.gather(centers_xyz, centers)
    dist = tf.reduce_sum((neighbor_xyz - neighbor_center_xyz)**2, axis=-1)
    inside_mask = dist <= radius**2
    neighbors = tf.boolean_mask(neighbors, inside_mask)
    centers = tf.boolean_mask(centers, inside_mask)
    return tf.stack([neighbors, centers], axis=1)
```
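The least obvious part of `tf_rnn` is the "fetch neighbor points" step, which expands every non-empty (start, count) cell run into consecutive indices without a Python loop: it scatters a correction at the first position of each run into a tensor of ones and then takes a cumulative sum. The plain-Python sketch below mirrors that arithmetic, with `itertools.accumulate` standing in for `tf.cumsum` and a loop over runs standing in for `tf.scatter_nd`. The function names are hypothetical, and, like the TF code after masking empty cells, it assumes every count is positive.

```python
from itertools import accumulate

def expand_runs(starts, counts):
    """Concatenate range(s, s + n) for each (s, n), using the cumsum trick from tf_rnn."""
    if not counts:
        return []
    pos = list(accumulate(counts, initial=0))[:-1]  # exclusive cumsum: run starts
    inc = [1] * sum(counts)                         # +1 steps inside each run
    inc[pos[0]] += starts[0] - 1
    for k in range(1, len(starts)):
        # jump from the last index of run k-1 (starts[k-1] + counts[k-1] - 1) to starts[k]
        inc[pos[k]] += (starts[k] - 1) - starts[k - 1] - (counts[k - 1] - 1)
    return list(accumulate(inc))

def repeat_ids(ids, counts):
    """ids[k] repeated counts[k] times, mirroring how tf_rnn builds `centers`."""
    if not counts:
        return []
    pos = list(accumulate(counts, initial=0))[:-1]
    inc = [0] * sum(counts)
    inc[pos[0]] += ids[0]
    for k in range(1, len(ids)):
        inc[pos[k]] += ids[k] - ids[k - 1]
    return list(accumulate(inc))

print(expand_runs([4, 0, 7], [2, 3, 1]))  # [4, 5, 0, 1, 2, 7]
print(repeat_ids([0, 0, 1], [2, 3, 1]))   # [0, 0, 0, 0, 0, 1]
```

The payoff of this formulation is that the output length (`total_num_neighbors`) is the only data-dependent size, so the whole expansion stays a fixed sequence of vectorized TF ops.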
They are not well optimized, and I found them slow compared to the scikit-learn implementation. I hope they are helpful to you.
Thanks a lot @WeijingShi
No problem.
Thanks for sharing the codebase. The script is very well written and easy to follow. As I followed your graph_gen and train scripts, I can see that the graphs are constructed at different levels first, and then the GNN layers choose specific graphs to process or aggregate. Is it possible to create graphs dynamically at each layer?