johnhw / differentiable_sorting

Differentiable bitonic sorting
MIT License

Feature Request: Configurable comparator function #2

Closed Ghost---Shadow closed 4 years ago

Ghost---Shadow commented 4 years ago

I was trying to build a differentiable sorter myself, and it's great to see such high-quality existing work.

I see that this currently only compares scalar values. I was planning to compare arbitrarily shaped tensors, such as images or waveforms. What would you suggest as the best approach to this problem? Are you aware of any existing literature that I can refer to?

johnhw commented 4 years ago

Hi Ghost---Shadow. I have added comparison_sort() to do what you suggest. Usage:

matrices = bitonic_matrices(n)
comparison_sort(matrices, tensor_to_sort, compare_fn)

compare_fn should return signed scores (e.g. in the range [-1, 1]). The axis to sort along must be the first axis, so tensor_to_sort should have shape [n, ...].

There is a brief example in notebooks/Variations.ipynb
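To make the idea of a signed-score comparator concrete, here is a minimal, self-contained NumPy sketch. It does not use comparison_sort() or reproduce its internals: mean_compare, soft_compare_swap and soft_transposition_sort are hypothetical names, the comparator signature and the sign convention (positive when the first argument is "larger") are assumptions, and a simple odd-even transposition network stands in for the bitonic network.

```python
import numpy as np

def mean_compare(a, b, steepness=10.0):
    """Hypothetical comparator returning signed scores in [-1, 1].

    a and b have shape [m, ...]; the result has shape [m].
    Assumed convention: positive when a has the larger mean.
    """
    axes = tuple(range(1, a.ndim))
    diff = a.mean(axis=axes) - b.mean(axis=axes)
    return np.tanh(steepness * diff)

def soft_compare_swap(a, b, compare_fn):
    """Blend each pair (a[i], b[i]) into a soft (smaller, larger) pair."""
    score = compare_fn(a, b)                      # shape [m], in [-1, 1]
    w = (score + 1.0) / 2.0                       # in [0, 1]; near 1 when a > b
    w = w.reshape((-1,) + (1,) * (a.ndim - 1))    # broadcast over trailing axes
    larger = w * a + (1.0 - w) * b
    smaller = w * b + (1.0 - w) * a
    return smaller, larger

def soft_transposition_sort(x, compare_fn):
    """Odd-even transposition sort along axis 0, built from soft swaps."""
    x = np.array(x, dtype=float)                  # work on a float copy
    n = x.shape[0]
    for step in range(n):                         # n rounds suffice for n items
        lo_idx = np.arange(step % 2, n - 1, 2)    # left index of each pair
        if lo_idx.size == 0:
            continue
        smaller, larger = soft_compare_swap(x[lo_idx], x[lo_idx + 1], compare_fn)
        x[lo_idx] = smaller
        x[lo_idx + 1] = larger
    return x

# Four constant 2x2 "images", ordered by mean intensity after sorting.
images = np.stack([np.full((2, 2), v) for v in [3.0, 1.0, 2.0, 0.0]])
sorted_images = soft_transposition_sort(images, mean_compare)
print(sorted_images.mean(axis=(1, 2)))
```

Every step is a smooth blend, so the same arithmetic can be differentiated once it is written in an autodiff framework's array operations (the in-place index assignment would need that framework's functional equivalent). A lower steepness gives softer swaps and smoother gradients at the cost of a blurrier ordering.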

Ghost---Shadow commented 4 years ago

Yes, something like that. In fact, I made an example of differentiable bubble sort with a DNN as the comparator function: https://github.com/Ghost---Shadow/differentiable-programming-examples/blob/master/bubble-sort.ipynb

I am finding it difficult to find resources on differentiable programming. Are you aware of any books/guides/youtube channels/papers which may help me?

Also, check out some of the other notebooks I made. You might find them interesting: https://github.com/Ghost---Shadow/differentiable-programming-examples

johnhw commented 4 years ago

I'm afraid I don't really have good resources. Some go-to papers are:

I'm not aware of a good collection of tricks, primitives, and so on.

Thanks for the notebook links, these are very nice demos. Have you considered using JAX for simpler demos? The code can get out of your way more than the TF examples...

Ghost---Shadow commented 4 years ago

Once I have enough content, I think I should make an open-source book, something like "Introduction to differentiable computing".

I plan to use these as loss functions to train neural networks, so it has to be a neural network library: TensorFlow or PyTorch.