teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

Weighted soft_rank #76

Open davips opened 1 year ago

davips commented 1 year ago

A weighted soft_rank would be a great addition to have! It could be weighted by element index, by a weighting function, or by a vector of weights.

My initial attempt was the following (part of a larger piece of code that calculates a weighted Spearman rho based on the wCorr package):

from torch import cumsum, diff, hstack
from torchsort import soft_rank

def wsrank(x, w, regularization="l2", regularization_strength=1.):
    """
    >>> import torch
    >>> soft_rank(torch.tensor([[1., 2, 2, 3]]))
    tensor([[1.5000, 2.5000, 2.5000, 3.5000]])
    >>> wsrank(torch.tensor([[1., 2, 2, 3]]), torch.tensor([1., 1, 1, 1]))
    tensor([1.5000, 2.5000, 2.5000, 3.5000])
    >>> wsrank(torch.tensor([[1., 2, 3, 4]]), torch.tensor([1., 1/2, 1/3, 1/4]))
    tensor([1.0000, 1.5000, 1.8333, 2.0833])
    >>> wsrank(torch.tensor([[1., 2, 3, 4]], requires_grad=True), torch.tensor([1., 1/2, 1/3, 1/4])).sum().backward()
    """
    # Soft ranks of the single-row input, flattened to 1-D.
    r = soft_rank(x, regularization=regularization, regularization_strength=regularization_strength).view(x.shape[1])
    # Successive rank increments: d[0] = r[0], d[i] = r[i] - r[i-1].
    d = hstack([r[0], diff(r)])
    # Weighted cumulative sum of the increments.
    s = cumsum(d * w, dim=0)
    return s

However, it seems too optimistic (near 1.0) when compared to, e.g., weightedtau (around 0.6 in a random test I ran here). The soft Spearman from the original README works fine, being only slightly more optimistic (~5% in some tests) than its hard counterpart, which makes sense to me.

davips commented 1 year ago

A soft, and possibly weighted, kendall-tau B would also be a great thing to have. It probably doesn't need sorting or ranking, just a sigmoid-like function to soften the agreement/disagreement/tie counters. We could take advantage of the GPU by parallelizing the naive O(n²) algorithm, instead of the Knight O(n log n) algorithm adopted by S. Vigna in the scipy/Cython implementation.

teddykoker commented 1 year ago

Thanks for the suggestions @davips. Do you have any references for the weighted rank? It's not exactly clear to me what the intended output of the function would be (even in an exact/"hard" setting).

I agree that a soft kendall-tau B would be great to have, although it may be out of scope of this library. I believe this could likely be done, as you said, without any sorting or ranking. I'm thinking you'd be able to do this with just plain PyTorch by constructing an n x n pairwise difference matrix for each variable, multiplying by some regularization value (so differences $\to \{-\infty, 0, \infty\}$ as the regularization term $\to \infty$), then handling the concordant/discordant pairs with activation functions. Definitely feasible with differentiable operations, but I'm not sure how "smooth" it would be for low regularization values. Not sure if I have the time to implement this myself, but I'd be happy to iterate on any ideas you might have.
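
Something along these lines is roughly what I have in mind (just a sketch: the soft_kendall_tau name and reg value are placeholders, tanh maps concordant pairs to roughly +1 and discordant pairs to roughly -1, and it ignores the tie correction that distinguishes tau-b from tau-a):

import torch

def soft_kendall_tau(x, y, reg=10.0):
    # n x n matrices of pairwise differences for each variable
    dx = x.unsqueeze(1) - x.unsqueeze(0)
    dy = y.unsqueeze(1) - y.unsqueeze(0)
    # tanh(reg * dx * dy) -> +1 for concordant pairs, -1 for discordant, ~0 near ties
    agreement = torch.tanh(reg * dx * dy)
    i, j = torch.triu_indices(len(x), len(x), offset=1)
    # average over the n * (n - 1) / 2 distinct pairs; roughly tau-a for large reg
    return agreement[i, j].mean()

x = torch.tensor([1., 2., 3., 4.], requires_grad=True)
y = torch.tensor([1., 3., 2., 4.])
soft_kendall_tau(x, y).backward()  # gradients flow end to end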

davips commented 11 months ago

@teddykoker, your suggestion is pretty similar to what I am trying now. If we find common ground, I can contribute to the library. I often use Python packages as a means to organize finished work and make it available to future work, so when I have some free time it is certainly possible.

The weighting scheme is simple. Higher ranks (i.e. lower values) get higher weights, defined by a custom function which could be anything, e.g., a Cauchy pdf or a harmonic progression (not as nice as a proper distribution, but it is what Vigna recommends for weighted tau; more details later).
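
To make the weigher concrete, something like the following is what I mean (function names are just illustrative; the Cauchy one matches scipy.stats.cauchy(0).pdf but is written in torch so it stays differentiable):

import math
import torch

def cauchy_weigher(ranks):
    # same curve as scipy.stats.cauchy(0).pdf, evaluated at the rank
    return 1.0 / (math.pi * (1.0 + ranks ** 2))

def harmonic_weigher(ranks):
    # harmonic weights 1/r, the weigher Vigna recommends for weighted tau
    return 1.0 / ranks

ranks = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(cauchy_weigher(ranks))    # tensor([0.1592, 0.0637, 0.0318, 0.0187])
print(harmonic_weigher(ranks))  # tensor([1.0000, 0.5000, 0.3333, 0.2500])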

This is my soft_kendalltau so far:

from torch import sigmoid, sum, triu_indices

def pdiffs(x):
    # all pairwise differences x[i] - x[j] for i < j (upper triangle, no diagonal)
    dis = x.unsqueeze(1) - x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def surrogate_tau(a, b, reg=10):
    # sigmoid(reg * da * db) -> 1 for concordant pairs, 0 for discordant, 0.5 for ties
    da, db = pdiffs(a), pdiffs(b)
    return sum(sigmoid(reg * da * db))

# `pred` and `target` are 1-D score tensors; normalize by the number of pairs.
soft_kendalltau = surrogate_tau(pred, target) / (len(pred) * (len(pred) - 1) / 2)
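
A quick sanity check of the block above, with made-up tensors and assuming scipy is available for the hard reference; for large reg the surrogate tends to the fraction of concordant pairs, i.e. roughly (tau + 1) / 2 when there are no ties:

import torch
from scipy import stats

pred = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0], requires_grad=True)
target = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

soft = surrogate_tau(pred, target, reg=100) / (len(pred) * (len(pred) - 1) / 2)
hard, _ = stats.kendalltau(pred.detach().numpy(), target.numpy())

print(2 * soft.item() - 1)  # ~0.6
print(hard)                 # 0.6 (8 concordant, 2 discordant pairs out of 10)
soft.backward()             # differentiable, unlike the scipy version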

This is my soft_weighted_kendalltau so far:

from torch import sigmoid, sum, triu_indices

def pdiffs(x):
    # all pairwise differences x[i] - x[j] for i < j
    dis = x.unsqueeze(1) - x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def psums(x):
    # all pairwise sums x[i] + x[j] for i < j (combines the two weights of a pair)
    dis = x.unsqueeze(1) + x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def surrogate_wtau(a, b, w, reg=10):
    # each pair's soft concordance is weighted by the sum of its elements' weights
    da, db, sw = pdiffs(a), pdiffs(b), psums(w)
    return sum(sigmoid(reg * da * db) * sw) / sum(sw)

soft_weighted_kendalltau = surrogate_wtau(pred, target, weights)

# Missing part: the weights are defined according to the rank of each value in `target`.
#               Such a ranking depends on `soft_rank`.
#               I will try with `cau = scipy.stats.cauchy(0).pdf`:
#               w = tensor([cau(r) for r in soft_rank(target)], requires_grad=True)
#               Perhaps this will work fine.
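
One possible way to fill in that missing part without leaving torch (a sketch reusing surrogate_wtau from above; the Cauchy pdf is written in torch because rebuilding the weights from scipy outputs, as in the comment, would detach them from the graph, and soft_rank expects a 2-D batch):

import math
import torch
from torchsort import soft_rank

def cauchy_pdf(r):
    # same values as scipy.stats.cauchy(0).pdf, but differentiable in torch
    return 1.0 / (math.pi * (1.0 + r ** 2))

target = torch.tensor([1.0, 2.0, 3.0, 4.0])
pred = torch.tensor([1.0, 3.0, 2.0, 4.0], requires_grad=True)

ranks = soft_rank(target.unsqueeze(0)).squeeze(0)  # soft_rank works on (batch, n)
weights = cauchy_pdf(ranks)                        # larger weight for smaller ranks

loss = -surrogate_wtau(pred, target, weights)      # maximize the weighted surrogate
loss.backward()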

I am not experienced with regularization. I assume that lower reg values make the surrogate softer and easier for gradient descent, at the expense of a less faithful approximation to the real (hard) metric.
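
For example (illustrative numbers only), a single concordant pair with difference product da * db = 0.1 goes from nearly undecided to a hard count as reg grows:

import torch

d = torch.tensor(0.1)  # a small, concordant da * db product
for reg in (1, 10, 100):
    print(reg, torch.sigmoid(reg * d).item())
# reg=1   -> 0.525  (smooth and easy to optimize, but almost counted as a tie)
# reg=10  -> 0.731
# reg=100 -> ~1.0   (essentially the hard count, but with a nearly zero gradient)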

The paper on weighted tau doesn't seem trivial to follow (at first glance), and the implementation is so optimized that it looks more complicated than the problem really is: scipy weighted tau

In a nutshell: you could help me by taking a look at the soft_kendalltau and soft_weighted_kendalltau code above, to assess whether it makes sense and to decide whether it falls within a possible scope for the library.

davips commented 11 months ago

A weighted soft_rank implementation would be a more general addition to the library, and could be used to replace parts of the code I wrote above. However, it may be too trivial, as it seems to be just soft_rank(x) * weigher(soft_rank(x)).
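
Spelled out, that could look something like this (the default weigher below is just the Cauchy pdf example from before, and the function name is illustrative):

import math
import torch
from torchsort import soft_rank

def weighted_soft_rank(x, weigher=lambda r: 1.0 / (math.pi * (1.0 + r ** 2)), **kwargs):
    # scale each soft rank by a weight computed from that same rank
    r = soft_rank(x, **kwargs)
    return r * weigher(r)

x = torch.tensor([[1.0, 2.0, 2.0, 3.0]], requires_grad=True)
weighted_soft_rank(x).sum().backward()  # stays differentiable end to end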