offchan42 / superkeras

:rocket: A bunch of Keras utilities and paper implementations written under TensorFlow backend
GNU General Public License v3.0

Training a permutation invariant model to learn set partitions #5

Open BRO-HAMMER opened 4 years ago

BRO-HAMMER commented 4 years ago

@off99555 first of all, thank you for the implementation of the Permutational Module, it's very cool.

I'm trying to find a way to build a model that can learn, in a supervised way, to find partitions within a set of data points. The criteria for grouping the elements have to be learned from the training data, which I could represent as a boolean adjacency matrix, for example (where A[i, j] = 1 means element i and element j are in the same group).
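A minimal sketch of turning per-element group labels into the padded adjacency target described above. The label encoding (an integer group id per element, `None` for an outlier), the helper name `labels_to_adjacency`, and the choice to give outliers an all-zero row (diagonal included) are all assumptions for illustration:

```python
from typing import List, Optional, Sequence


def labels_to_adjacency(labels: Sequence[Optional[int]], n_max: int) -> List[List[int]]:
    """Build a zero-padded n_max x n_max same-group matrix from group labels.

    Hypothetical encoding (not from the issue):
      - labels[i] is an integer group id, or None for an outlier;
      - an outlier is grouped with nothing, not even itself (all-zero row);
      - positions len(labels)..n_max-1 are padding and stay 0.
    """
    if len(labels) > n_max:
        raise ValueError("more elements than n_max")
    adj = [[0] * n_max for _ in range(n_max)]
    for i, gi in enumerate(labels):
        for j, gj in enumerate(labels):
            # Same non-None label means "same group".
            if gi is not None and gi == gj:
                adj[i][j] = 1
    return adj


# Two groups {0, 2} and {1, 3}, one outlier (index 4), padded to 6 slots.
A = labels_to_adjacency([0, 1, 0, 1, None], n_max=6)
for row in A:
    print(row)
```

Because `A[i][j]` only depends on whether two labels match, the matrix is symmetric, and swapping two elements in the input swaps the corresponding rows and columns of `A`.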

So far, the best way I have found to model the problem is to treat the set as a sequence, feed it to a Bidirectional RNN (many-to-many), and then predict each bit of the adjacency matrix as a binary classification problem (using a Dense layer with sigmoid activation). But there isn't really a good way of ordering the points (similar to the point cloud problem), so I don't really like this approach. The number of points in the set and the number of groups they can form are variable, and there could even be outliers that don't belong to any group. I'm keeping the input and output sizes constant by using zero padding.
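With zero padding, many entries of the target matrix can be padding, so a plain mean over the whole matrix may reward a model simply for predicting 0 on empty slots. One common remedy, sketched here under the assumption that the number of real elements `n_real` is known per sample and that those elements occupy the first `n_real` positions, is to average the binary cross-entropy over the real block only. `masked_bce` is a hypothetical helper, not part of superkeras:

```python
import math
from typing import List


def masked_bce(pred: List[List[float]], target: List[List[int]],
               n_real: int, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over the top-left n_real x n_real block.

    Rows and columns with index >= n_real are treated as padding and ignored.
    """
    total = 0.0
    for i in range(n_real):
        for j in range(n_real):
            p = min(max(pred[i][j], eps), 1.0 - eps)  # clip to avoid log(0)
            y = target[i][j]
            total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / (n_real * n_real)


target = [[1, 0, 0], [0, 1, 0], [0, 0, 0]]  # 2 real elements padded to 3
pred = [[0.9, 0.1, 0.8], [0.1, 0.9, 0.8], [0.8, 0.8, 0.8]]
loss = masked_bce(pred, target, n_real=2)
print(round(loss, 4))  # 0.1054: the 0.8s in the padded slots do not count
```

In Keras the same idea can be expressed by multiplying the element-wise loss by a 0/1 mask before averaging.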

I'm wondering whether this module would make more sense for this task. In that case, I guess I would need to call the module like this:

perm = PermutationalModule( (features,), n_objects, [repeat_layers(Dense, [n_cells1, n_cells2], activation="relu"), repeat_layers(Dense, [n_objects], activation="sigmoid")] )

And each output should be compared against the corresponding row of the adjacency matrix using binary cross-entropy. Does that make sense? I'd appreciate any help.

Thank you!
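One point worth checking with the proposed `Dense(n_objects)` head: column j of each output row has to refer to input position j, but a set encoder that pools over the other elements (which is what makes it permutation invariant) gives element i an output that does not depend on their order, so it cannot know which column belongs to which neighbour. An alternative, sketched below as an assumption rather than as the library's API, is to score every pair (i, j) with one shared, symmetric function, so that permuting the inputs permutes the predicted matrix exactly like the target. `toy_score` is a hypothetical hand-written stand-in for a learned pair network:

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def pairwise_probs(points: List[Vector],
                   score: Callable[[Vector, Vector], float]) -> List[List[float]]:
    """P[i][j] = sigmoid(score(x_i, x_j)) for every ordered pair (i, j)."""
    def sigmoid(z: float) -> float:
        return 1.0 / (1.0 + math.exp(-z))
    # Column j is tied to element j by construction.
    return [[sigmoid(score(a, b)) for b in points] for a in points]


def toy_score(a: Vector, b: Vector) -> float:
    # Hypothetical stand-in for a learned g(x_i, x_j): close points score
    # high, far points low; symmetric in (a, b).
    dist2 = sum((u - v) ** 2 for u, v in zip(a, b))
    return 4.0 - 2.0 * dist2


def permute(points, perm):
    return [points[k] for k in perm]


def permute_matrix(m, perm):
    # Reorder rows and columns the same way the inputs were reordered.
    return [[m[r][c] for c in perm] for r in perm]


X = [(0.0, 0.0), (5.0, 5.0), (0.1, 0.0), (5.0, 5.1)]
perm = [2, 0, 3, 1]
P = pairwise_probs(X, toy_score)
P_perm = pairwise_probs(permute(X, perm), toy_score)
print(P_perm == permute_matrix(P, perm))  # True: permuting inputs permutes P
```

The resulting matrix can be compared element-wise with the adjacency target using binary cross-entropy. The cost is one pair evaluation per ordered pair, so it grows quadratically with the set size.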

offchan42 commented 4 years ago

Since you mentioned point clouds, you can also use the PointNet implementation in this repo to train your model. PointNet is more efficient than the Permutational Module if you have lots of features.
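The core idea behind PointNet is to apply the same small network to every point independently and then max-pool each feature over the points, which yields a global vector unaffected by point order; for per-point outputs (as in segmentation) that global vector is concatenated back onto each point's own features. Roughly speaking, the shared network runs once per point rather than once per pair. A minimal pure-Python sketch of that structure, where `shared_mlp` is a hypothetical fixed function standing in for the learned shared MLP:

```python
from typing import List, Sequence


def shared_mlp(point: Sequence[float]) -> List[float]:
    # Hypothetical stand-in for PointNet's shared per-point MLP (in Keras this
    # role is typically played by Dense or Conv1D layers applied to every point).
    x, y = point
    return [max(0.0, x + y), max(0.0, x - y), max(0.0, 2.0 * y - x)]


def global_feature(points: List[Sequence[float]]) -> List[float]:
    # Max-pool each feature over all points: the result ignores point order.
    feats = [shared_mlp(p) for p in points]
    return [max(column) for column in zip(*feats)]


def per_point_features(points: List[Sequence[float]]) -> List[List[float]]:
    # Segmentation-style input: each point's local feature + the global feature.
    g = global_feature(points)
    return [shared_mlp(p) + g for p in points]


pts = [(1.0, 2.0), (3.0, -1.0), (0.5, 0.5)]
print(global_feature(pts))                                # [3.0, 4.0, 3.0]
print(global_feature(pts) == global_feature(pts[::-1]))   # True
```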