Hey, I believe your backward function is NOT numerically stable: exp(a[n][row][col] + b[n][col][k] - part[n][row][k]) can be troublesome. Looking at the original code from harvardnlp/genbmm, they seem to save the maximum value from the forward pass to overcome this issue.
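To spell out where that factor comes from (a sketch in the same notation, assuming the standard log-space matmul): the forward pass computes

part[n][row][k] = log(sum over col of exp(a[n][row][col] + b[n][col][k]))

and its gradient with respect to a contains exactly the factor exp(a[n][row][col] + b[n][col][k] - part[n][row][k]). That factor is only safe to evaluate directly if part itself was computed accurately (in particular, did not underflow to -inf). If instead the per-output maximum m[n][row][k] = max over col of (a[n][row][col] + b[n][col][k]) saved during the forward pass is reused, the same factor can be written as

exp(a[n][row][col] + b[n][col][k] - m[n][row][k]) / (sum over col of exp(a[n][row][col] + b[n][col][k] - m[n][row][k]))

where every exponent is at most zero and the denominator is at least one, so it can neither overflow nor divide by zero.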
@sustcsonglin Thanks! That's really useful to point out. I was definitely going off an older version, before they updated that part.
I think rather than storing the maxes, it's probably better to use signed log-gradients for numerical stability in the backward pass (which avoids additional reads from GPU memory). I probably won't update the post, but here's a small library I wrote to handle signed log operations.
/**
 * Defines some helper functions for signed log-space operations.
 *
 * As a convention, true = positive and false = negative. Casting
 * unsigned to signed works the same way as with base types (the value is
 * assumed to be positive and already in log space).
 *
 * sign: the sign stored on the `signed_t` (true = positive)
 * val: the log of the absolute value stored on the `signed_t`
 * set_: sets a = b
 * slog: gets the signed log of any number
 * sexp: signed exponential (converts back out of log space)
 * sum: adds together two values in signed log-space
 * mul: multiplies together two values in signed log-space
 *
 * The unqualified exp / log / log1p / abs calls below resolve to the CUDA
 * device math functions when compiled as device code (where each helper
 * would also need a __device__ annotation) and to the <cmath> functions
 * on the host.
 */

#include <cmath>

// Log-space zero and infinity. If the surrounding extension code doesn't
// already define these, the definitions below work.
#ifndef NEG_INF
#define NEG_INF (-INFINITY)
#endif
#ifndef POS_INF
#define POS_INF (INFINITY)
#endif

template <typename scalar_t>
struct signed_t {
  bool sign;     // true = positive, false = negative
  scalar_t val;  // log of the absolute value
  constexpr signed_t(bool b, scalar_t v) : sign(b), val(v) {}
};

template <typename scalar_t>
void set_(signed_t<scalar_t>& a, const signed_t<scalar_t>& b) {
  a.sign = b.sign;
  a.val = b.val;
}

template <typename scalar_t>
signed_t<scalar_t> slog(const scalar_t& x) {
  return signed_t<scalar_t>(x >= 0, log(abs(x)));
}

template <typename scalar_t>
scalar_t sexp(const signed_t<scalar_t>& x) {
  return x.sign ? exp(x.val) : -exp(x.val);
}

template <typename scalar_t>
signed_t<scalar_t> sum(
    const signed_t<scalar_t>& a,
    const signed_t<scalar_t>& b) {
  // Order the arguments so |a| >= |b|; the exponent below is then <= 0.
  if (b.val > a.val)
    return sum(b, a);
  if (a.val == NEG_INF)  // Both values are zero.
    return {a.sign, NEG_INF};
  if (a.val == POS_INF)  // The dominant value is infinite.
    return {a.sign, POS_INF};
  const scalar_t d = (a.sign ^ b.sign) ? -exp(b.val - a.val) : exp(b.val - a.val);
  // log1p is more accurate than log(1 + d) when d is close to zero.
  return signed_t<scalar_t>(a.sign, a.val + log1p(d));
}

template <typename scalar_t>
signed_t<scalar_t> sum(const signed_t<scalar_t>& a, const scalar_t& b) {
  return sum(a, signed_t<scalar_t>(true, b));
}

template <typename scalar_t>
signed_t<scalar_t> sum(const scalar_t& a, const signed_t<scalar_t>& b) {
  return sum(signed_t<scalar_t>(true, a), b);
}

template <typename scalar_t>
signed_t<scalar_t> sum(const scalar_t& a, const scalar_t& b) {
  return sum(signed_t<scalar_t>(true, a), signed_t<scalar_t>(true, b));
}

template <typename scalar_t>
signed_t<scalar_t> mul(
    const signed_t<scalar_t>& a,
    const signed_t<scalar_t>& b) {
  // A zero operand (val == NEG_INF) makes the product zero; handling it
  // explicitly also avoids the NaN that -inf + inf would otherwise produce.
  if (a.val == NEG_INF || b.val == NEG_INF)
    return signed_t<scalar_t>(!(a.sign ^ b.sign), NEG_INF);
  return signed_t<scalar_t>(!(a.sign ^ b.sign), a.val + b.val);
}

template <typename scalar_t>
signed_t<scalar_t> mul(const signed_t<scalar_t>& a, const scalar_t& b) {
  return signed_t<scalar_t>(a.sign, a.val + b);
}

template <typename scalar_t>
signed_t<scalar_t> mul(const scalar_t& a, const signed_t<scalar_t>& b) {
  return signed_t<scalar_t>(b.sign, b.val + a);
}

template <typename scalar_t>
signed_t<scalar_t> mul(const scalar_t& a, const scalar_t& b) {
  return signed_t<scalar_t>(true, a + b);
}

template <typename scalar_t>
constexpr signed_t<scalar_t> zero() {
  return signed_t<scalar_t>(true, NEG_INF);
}
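For what it's worth, here's a rough sketch of how the helpers might be used to accumulate a gradient term in signed log space. The function and variable names below are made up for illustration and aren't from the actual extension; in the real kernel this would be a __device__ function.

// Hypothetical accumulator: computes sum_k grad_part[k] * exp(logits[k] - part)
// without leaving signed log space, converting back only at the end.
template <typename scalar_t>
scalar_t accumulate_grad(
    const scalar_t* grad_part,  // upstream gradients; may be negative
    const scalar_t* logits,     // the a[...] + b[...] terms, already in log space
    const scalar_t part,        // the forward logsumexp output for this entry
    const int K) {
  signed_t<scalar_t> acc = zero<scalar_t>();
  for (int k = 0; k < K; k++) {
    // slog captures the sign of the upstream gradient; the second factor is a
    // plain log-space value, so mul just adds it to the log magnitude.
    const signed_t<scalar_t> term = mul(slog(grad_part[k]), logits[k] - part);
    set_(acc, sum(acc, term));
  }
  return sexp(acc);  // back out of log space
}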