johnhw / differentiable_sorting

Differentiable bitonic sorting
MIT License

Sorting 3D or 4D vectors #3

Open zimonitrome opened 2 years ago

zimonitrome commented 2 years ago

I love this library and got it working well with PyTorch:

# 2D tensor
input_tensor = torch.tensor([
        [1, 5], 
        [30, 30], 
        [6, 9], 
        [80, -2],
]).float()

mask = torch.tensor([1, 0]).float()

vector_sort(
    bitonic_matrices(4),
    input_tensor,
    lambda x: x @ mask,  # sort key: first column (selected by mask = [1, 0])
    alpha=1.0
)

> tensor([
        [80.0000, -2.0000],
        [30.0000, 30.0000],
        [ 5.9665,  8.9732],
        [ 1.0335,  5.0268]
])

But I am now trying to extend this to higher dimensions:

# 3D tensor
input_tensor = torch.tensor([
    [
        [1, 5], 
        [30, 30], 
        [6, 9], 
        [80, -2]
    ],
    [
        [2, 6], 
        [31, 31], 
        [7, 10], 
        [81, -1]
    ],
]).float()

target_tensor = torch.tensor([
    [
        [80, -2],
        [30, 30], 
        [6, 9], 
        [1, 5], 
    ],
    [
        [81, -1],
        [31, 31], 
        [7, 10], 
        [2, 6], 
    ],
])

mask = torch.tensor([1, 0]).float()

vector_sort(
    bitonic_matrices(8),
    input_tensor,
    lambda x: x @ mask,
    alpha=1.0
)

But I receive the error:

~\anaconda3\lib\site-packages\differentiable_sorting\torch\differentiable_sorting_torch.py in vector_sort(matrices, X, key, alpha)
     85         x = key(X)
     86         # compute weighting on the scalar function
---> 87         a, b = l @ x, r @ x
     88         a_weight = torch.exp(a * alpha) / (torch.exp(a * alpha) + torch.exp(b * alpha))
     89         b_weight = 1 - a_weight

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x8 and 2x4)

How would I go about extending the sorting to work for 3D or nD tensors?

johnhw commented 2 years ago

This is probably too late to be useful (sorry!), but you'd need to either:

(a) Unravel your data, sort it, and then reshape it back into the nD tensor, if what you want is just to sort the elements independently.

(b) Define a (differentiable) comparator function if you want to use the tensor structure in the sorting, and then call comparison_sort(matrices, my_comparator) (e.g. you could sort the matrices by sum of rows using this method).
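For the batched case in the question, option (a) could look roughly like the sketch below: collapse the leading dimensions into one batch axis, sort each (N, D) slice on its own, and reshape back. The helper names `sort_over_leading_dims` and `hard_sort_by_first_column` are hypothetical and not part of the library. The hard sort is only a stand-in so the snippet runs without the package; in practice the per-slice call would presumably be something like vector_sort(bitonic_matrices(4), rows, key, alpha). NumPy is used to keep the sketch light; `Tensor.reshape` and `torch.stack` play the same role in PyTorch. Judging by the error message (a 4x8 matrix against a 2x4 key output), the network would also need to be built for the 4 rows of each slice, i.e. bitonic_matrices(4) rather than bitonic_matrices(8).

```python
import numpy as np


def sort_over_leading_dims(X, sort_2d):
    """Apply a 2D sorter to every (N, D) slice of an (..., N, D) array."""
    X = np.asarray(X, dtype=float)
    n, d = X.shape[-2:]
    flat = X.reshape(-1, n, d)  # collapse all leading dims into one batch axis
    out = np.stack([sort_2d(rows) for rows in flat])  # sort each slice independently
    return out.reshape(X.shape)  # restore the original nD shape


def hard_sort_by_first_column(rows):
    # Stand-in (assumption): a non-differentiable descending sort on column 0.
    # Swap in a call to the library's vector_sort for a differentiable version.
    order = np.argsort(-rows[:, 0])
    return rows[order]


input_array = np.array([
    [[1, 5], [30, 30], [6, 9], [80, -2]],
    [[2, 6], [31, 31], [7, 10], [81, -1]],
])

result = sort_over_leading_dims(input_array, hard_sort_by_first_column)
print(result.shape)  # (2, 4, 2)
```

With the stand-in sorter, `result` matches the target_tensor from the question, since each batch is ordered by its first column independently of the other.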

johnhw commented 2 years ago

Alternatively,

(c) if your sorting doesn't require a fully custom comparator, and you can instead map each element (e.g. a row vector) to a scalar (as in the row-sum example), you could use vector_sort() with a key function: the key maps each element to a scalar, and the elements are sorted by that value.
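To make the role of the key concrete, here is a minimal sketch of a single soft compare-and-swap on two rows. It is a hypothetical helper, not the library's code: it only mirrors the weighting visible in the traceback above, exp(a * alpha) / (exp(a * alpha) + exp(b * alpha)), and assumes the two outputs are blends of the input rows with that weight. Under that assumption it reproduces the slightly blended last two rows of the 2D example in the question.

```python
import math


def soft_compare_swap(row_a, row_b, key, alpha=1.0):
    """Softly order two rows by a scalar key; returns (larger_row, smaller_row)."""
    a, b = key(row_a), key(row_b)
    # Algebraically equal to exp(a*alpha) / (exp(a*alpha) + exp(b*alpha)).
    w = 1.0 / (1.0 + math.exp((b - a) * alpha))
    larger = [w * x + (1.0 - w) * y for x, y in zip(row_a, row_b)]
    smaller = [(1.0 - w) * x + w * y for x, y in zip(row_a, row_b)]
    return larger, smaller


def first_column(row):
    # Key function: maps a row vector to the scalar it is sorted by.
    return row[0]


larger, smaller = soft_compare_swap([6.0, 9.0], [1.0, 5.0], key=first_column)
print([round(v, 4) for v in larger])   # [5.9665, 8.9732]
print([round(v, 4) for v in smaller])  # [1.0335, 5.0268]
```

A larger alpha pushes the weight toward 0 or 1, so the outputs approach an exact swap; a smaller alpha blends the rows more.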