Implemented logsumexp2, a two-argument logsumexp that computes log(exp(a) + exp(b)) elementwise. For torch version 1.7+, see logaddexp.
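A minimal sketch of what a two-argument logsumexp can look like for torch versions without logaddexp. It uses the standard max-shift identity log(e^a + e^b) = m + log1p(exp(-|a - b|)) with m = max(a, b). This is an illustration under that assumption, not necessarily the exact code in this PR.

```python
import torch


def logsumexp2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Elementwise log(exp(a) + exp(b)) for two broadcastable tensors.

    Sketch of a fallback for torch versions without torch.logaddexp.
    """
    # Expand both inputs to a common shape so torch.where below sees matching shapes.
    a, b = torch.broadcast_tensors(a, b)
    # Factor out the larger argument: log(e^a + e^b) = m + log1p(exp(-|a - b|)).
    m = torch.max(a, b)
    # When a == b (including a == b == +/-inf) use a difference of 0 to avoid inf - inf = nan.
    diff = torch.where(a == b, torch.zeros_like(a), a - b)
    return m + torch.log1p(torch.exp(-torch.abs(diff)))
```

On torch 1.7+ the result should agree with torch.logaddexp(a, b) up to floating-point rounding. Because exp is only ever applied to a non-positive number, large inputs such as 1000.0 do not overflow.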
Updated the Gumbel intersection to use logsumexp2/logaddexp, so it now works with automatic broadcasting.
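A hedged sketch of a broadcasting Gumbel intersection. It assumes each box is stored as a min corner z and a max corner Z. The intersection takes a smooth max of the min corners and a smooth min of the max corners, which are the two expressions that appear in the inequalities below. The name gumbel_intersection and its signature are hypothetical.

```python
import torch


def gumbel_intersection(z1, Z1, z2, Z2, beta: float = 1.0):
    """Hypothetical helper: Gumbel intersection of boxes (z1, Z1) and (z2, Z2).

    z* are min corners, Z* are max corners; all inputs only need to be broadcastable.
    """
    # Smooth max of the min corners: beta * logaddexp(z1/beta, z2/beta) >= max(z1, z2).
    z = beta * torch.logaddexp(z1 / beta, z2 / beta)
    # Smooth min of the max corners: -beta * logaddexp(-Z1/beta, -Z2/beta) <= min(Z1, Z2).
    Z = -beta * torch.logaddexp(-Z1 / beta, -Z2 / beta)
    return z, Z
```

Because logaddexp broadcasts, passing boxes shaped (N, 1, D) and (1, M, D) yields all N x M pairwise intersections with shape (N, M, D), and no explicit expand or repeat is needed. On older torch versions, logsumexp2 could be swapped in for torch.logaddexp.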
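Performed random testing for numerical stability and to check whether the theoretical inequalities max(a,b) < beta*logaddexp(a/beta, b/beta) and min(a,b) > -beta*logaddexp(-a/beta, -b/beta) hold. These hold strictly in exact arithmetic for beta > 0. In floating point, however, the log1p correction term can underflow to zero, so logaddexp may return exactly the larger argument. Currently, applying torch.nextafter to the output of logaddexp seems to be enough to restore the strict inequalities.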
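A sketch of such a random test. The PR does not say exactly where nextafter is applied. This version assumes it nudges the logaddexp output one ulp toward +inf before rescaling by beta. The names smooth_max, smooth_min and check_inequalities are hypothetical. The default beta is a power of two, so that a/beta and beta*(...) are exact and the check is deterministic. For other beta values, rounding in the division and multiplication makes the result empirical, which matches the hedged "seems to be enough" above.

```python
import torch


def smooth_max(a, b, beta):
    # beta * logaddexp(a/beta, b/beta), nudged one ulp upward (assumed placement of nextafter).
    lse = torch.logaddexp(a / beta, b / beta)
    lse = torch.nextafter(lse, torch.full_like(lse, float("inf")))
    return beta * lse


def smooth_min(a, b, beta):
    # -beta * logaddexp(-a/beta, -b/beta), with the same upward nudge before negation.
    lse = torch.logaddexp(-a / beta, -b / beta)
    lse = torch.nextafter(lse, torch.full_like(lse, float("inf")))
    return -beta * lse


def check_inequalities(n=100_000, beta=0.5, scale=10.0, seed=0):
    gen = torch.Generator().manual_seed(seed)
    a = scale * torch.randn(n, generator=gen)
    b = scale * torch.randn(n, generator=gen)
    # Force large gaps where exp(-|a - b|) underflows and logaddexp returns the max exactly.
    b[: n // 10] = a[: n // 10] + 1e4
    upper_ok = torch.maximum(a, b) < smooth_max(a, b, beta)
    lower_ok = torch.minimum(a, b) > smooth_min(a, b, beta)
    return bool(upper_ok.all()), bool(lower_ok.all())
```

Dropping the nextafter calls should make the large-gap cases fail the strict comparisons, because there logaddexp returns exactly max(a/beta, b/beta).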