logsumexp #3

utterances-bot commented 3 years ago

Optimized Log-Sum-Exp PyTorch Function | Ben Bolte's Blog

https://ben.bolte.cc/logsumexp

sustcsonglin commented 3 years ago

Hey, I believe your backward function is NOT numerically stable: exp(a[n][row][col] + b[n][col][k] - part[n][row][k]) can be troublesome. Looking at the original code from harvardnlp/genbmm, they seem to save the maximum values from the forward pass to overcome this issue.
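
For reference, here's a minimal sketch of the trick being described, on a single row rather than the batched matrix product (the function names and shapes are illustrative, not genbmm's actual kernel): the forward pass saves the row max, and the backward pass reuses it so every exponent argument stays non-positive.

#include <algorithm>
#include <cmath>
#include <vector>

// Forward: compute logsumexp of a row, saving the max for the backward pass.
// Subtracting the max keeps every exponent argument <= 0, so exp() never
// overflows.
double logsumexp_forward(const std::vector<double>& x, double& max_out) {
  max_out = *std::max_element(x.begin(), x.end());
  double total = 0.0;
  for (const double v : x) total += std::exp(v - max_out);
  return max_out + std::log(total);
}

// Backward: the gradient weights are softmax(x), computed with the saved max
// rather than exp(x[i] - out), which can misbehave if `out` was computed
// without the shift.
std::vector<double> logsumexp_backward(
    const std::vector<double>& x, double max_in, double grad_out) {
  double total = 0.0;
  for (const double v : x) total += std::exp(v - max_in);
  std::vector<double> grad(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    grad[i] = grad_out * std::exp(x[i] - max_in) / total;
  return grad;
}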

codekansas commented 3 years ago

@sustcsonglin Thanks! That's really useful to point out. I was definitely going off an older version from before they updated that part.

codekansas commented 3 years ago

I think that rather than storing the maxes, it's probably better to use signed log-gradients for numerical stability in the backward pass, which avoids additional reads from GPU memory. I probably won't update the post, but here's a small library I wrote to handle signed log operations.

/**
 * Defines some helper functions for signed log-space operations.
 *
 * As a convention, true = positive and false = negative. Casting
 * unsigned to signed works the same way as with base types (the value
 * is assumed to be positive and already in log space).
 *
 * sign: gets the sign from the `signed_t`
 * val: gets the absolute value from the `signed_t`
 * set_: sets a = b
 * slog: gets the signed log of any number
 * sexp: signed exponent
 * sum: adds together two values in signed log-space
 * mul: multiplies together two values in signed log-space
 */

#include <cmath>
#include <limits>

// These expand inside the templates below, where `scalar_t` is in scope.
#define NEG_INF (-std::numeric_limits<scalar_t>::infinity())
#define POS_INF (std::numeric_limits<scalar_t>::infinity())

template <typename scalar_t>
struct signed_t {
  bool sign;     // true = positive, false = negative
  scalar_t val;  // log of the absolute value
  constexpr signed_t(bool b, scalar_t v) : sign(b), val(v) {}
};

template <typename scalar_t>
void set_(signed_t<scalar_t>& a, const signed_t<scalar_t>& b) {
  a.sign = b.sign;
  a.val = b.val;
}

template <typename scalar_t>
signed_t<scalar_t> slog(const scalar_t& x) {
  // Split a real number into its sign and the log of its magnitude.
  return signed_t<scalar_t>(x >= 0, std::log(std::abs(x)));
}

template <typename scalar_t>
scalar_t sexp(const signed_t<scalar_t>& x) {
  // Inverse of slog: exponentiate the magnitude and reapply the sign.
  return x.sign ? std::exp(x.val) : -std::exp(x.val);
}

template <typename scalar_t>
signed_t<scalar_t> sum(
    const signed_t<scalar_t>& a,
    const signed_t<scalar_t>& b) {
  // Order the arguments so that `a` has the larger log-magnitude; factoring
  // it out keeps the exponent argument non-positive, so exp never overflows.
  if (b.val > a.val)
    return sum(b, a);
  // Short-circuit infinities so the arithmetic below can't produce NaNs.
  if (a.val == NEG_INF)
    return {a.sign, NEG_INF};
  if (a.val == POS_INF)
    return {a.sign, POS_INF};
  // Opposite signs subtract; log1p is more accurate than log(1 + d) when
  // d is close to zero.
  const scalar_t d =
      a.sign ^ b.sign ? -std::exp(b.val - a.val) : std::exp(b.val - a.val);
  return signed_t<scalar_t>(a.sign, a.val + std::log1p(d));
}

template <typename scalar_t>
signed_t<scalar_t> sum(const signed_t<scalar_t>& a, const scalar_t& b) {
  return sum(a, signed_t<scalar_t>(true, b));
}

template <typename scalar_t>
signed_t<scalar_t> sum(const scalar_t& a, const signed_t<scalar_t>& b) {
  return sum(signed_t<scalar_t>(true, a), b);
}

template <typename scalar_t>
signed_t<scalar_t> sum(const scalar_t& a, const scalar_t& b) {
  return sum(signed_t<scalar_t>(true, a), signed_t<scalar_t>(true, b));
}

template <typename scalar_t>
signed_t<scalar_t> mul(
    const signed_t<scalar_t>& a,
    const signed_t<scalar_t>& b) {
  // A log-magnitude of -inf is an exact zero, which annihilates the
  // product; short-circuiting also avoids -inf + inf = NaN below.
  if (a.val == NEG_INF || b.val == NEG_INF)
    return signed_t<scalar_t>(true, NEG_INF);
  return signed_t<scalar_t>(!(a.sign ^ b.sign), a.val + b.val);
}

// Unsigned operands are taken to be positive values already in log space.
template <typename scalar_t>
signed_t<scalar_t> mul(const signed_t<scalar_t>& a, const scalar_t& b) {
  return signed_t<scalar_t>(a.sign, a.val + b);
}

template <typename scalar_t>
signed_t<scalar_t> mul(const scalar_t& a, const signed_t<scalar_t>& b) {
  return signed_t<scalar_t>(b.sign, b.val + a);
}

template <typename scalar_t>
signed_t<scalar_t> mul(const scalar_t& a, const scalar_t& b) {
  return signed_t<scalar_t>(true, a + b);
}

// The additive identity: a positive zero, i.e. a log-magnitude of -inf.
template <typename scalar_t>
constexpr signed_t<scalar_t> zero() {
  return signed_t<scalar_t>(true, NEG_INF);
}
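
As a quick usage sketch (the scalar arguments here are illustrative stand-ins for the kernel's tensor elements, not the post's actual code), the unstable exp(a + b - part) * grad product from the backward pass can be kept in signed log space and exponentiated only once at the end:

// Accumulate exp(a + b - part) * grad in signed log space, exponentiating
// only at the very end.
template <typename scalar_t>
scalar_t stable_grad_term(scalar_t a_val, scalar_t b_val, scalar_t part, scalar_t grad) {
  // Log of the softmax weight: a + b - part, with a positive sign.
  signed_t<scalar_t> weight(true, a_val + b_val - part);
  // Multiply by the (possibly negative) upstream gradient without leaving
  // log space; slog handles the sign.
  signed_t<scalar_t> result = mul(weight, slog(grad));
  // One exp at the end, instead of exponentiating each factor separately.
  return sexp(result);
}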