harvardnlp / genbmm

CUDA kernels for generalized matrix-multiplication in PyTorch

[feature request] support log-bmm for context-free grammars #6

Open sustcsonglin opened 4 years ago

sustcsonglin commented 4 years ago

I found log-bmm very useful for linear-chain CRFs: it saves memory and speeds things up. In context-free grammars, however, the rule A->BC requires a large amount of GPU memory, which is an even more serious problem, so it is difficult to increase the number of non-terminals or terminals on a single graphics card.

srush commented 4 years ago

It should just work? Is there a bug?

sustcsonglin commented 4 years ago

I am afraid it does not work for CFGs: logbmm only accepts two tensors, and each of them must be 3-dimensional.

For example, in the compound PCFG we have a rule tensor A->BC for each sentence of size (B, NT, NT, NT) (we only consider non-terminals here), and the shapes of the B and C child tensors are (B, n-w, w, NT) (we are using a linear scan here, i.e. processing all spans of the same width at once).

If we want to get the final A, which has shape (B, n-w, NT), it seems that we have to create a temporary tensor of shape (B, n-w, w, NT, NT) to combine B and C, and then reshape it to (B*(n-w)*w, NT*NT, 1). The grammar A->BC then needs to be expanded to (B*(n-w)*w, NT, NT*NT) to apply the logbmm function, so in this case we still need (B, (n-w)*w, NT, NT, NT) memory, since a broadcast expand of A->BC does not work because logbmm needs contiguous tensors. Both (n-w)*w and NT^3 can be very large.

I believe there is an inherent difference between CFGs and linear-chain models. The ideal situation would be to combine B, C, and A->BC into A directly, (B, n-w, w, NT) + (B, n-w, w, NT) + (B, NT, NT, NT) -> (B, n-w, w, NT) -> (B, n-w, NT), without any intermediate tensor larger than (B, n-w, w, NT).
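
To make the memory blow-up concrete, here is a rough sketch of the expansion route just described (assuming the current two-argument genbmm.logbmm(a, b) API and a GPU; shape names as above):

import torch
import genbmm

B, n_w, w, NT = 4, 20, 5, 30
left = torch.randn(B, n_w, w, NT).cuda()    # inside scores of left children (B)
right = torch.randn(B, n_w, w, NT).cuda()   # inside scores of right children (C)
rule = torch.randn(B, NT, NT, NT).cuda()    # log potentials for A -> B C

# 1) combine the two children: temporary tensor of shape (B, n-w, w, NT, NT)
children = (left.unsqueeze(-1) + right.unsqueeze(-2)).view(B * n_w * w, NT * NT, 1)

# 2) the grammar has to be expanded over the (n-w)*w spans and made contiguous
#    for logbmm, which materializes a (B*(n-w)*w, NT, NT*NT) tensor
grammar = (rule.view(B, 1, NT, NT * NT)
               .expand(B, n_w * w, NT, NT * NT)
               .contiguous()
               .view(B * n_w * w, NT, NT * NT))

out = genbmm.logbmm(grammar, children)          # (B*(n-w)*w, NT, 1)
out = out.view(B, n_w, w, NT).logsumexp(2)      # sum over split points -> (B, n-w, NT)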

srush commented 4 years ago

I see what you mean. So you are suggesting writing another intermediate operator that directly does both combinations without storing intermediates.

Let's do this together. Maybe you can give a minimal suggestion of what that operator would need to look like?

One idea would be to directly support einsum-style operators (https://pypi.org/project/opt-einsum/)?

srush commented 4 years ago

Or perhaps you are just suggesting that genbmm should support broadcasting along the first dimension? Would that work? (A->BC could be size (1, NT, NT * NT) and still be contiguous without ever explicitly expanding, right?)

srush commented 4 years ago

I like the second solution better. If you are motivated to give it a try, here's how to do it.

1) Edit this line so that you check both a.size(0) and b.size(0): https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L373

2) Set the block size based on the max of the two: https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L385

3) Pass the batch sizes into this function: https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L398

4) Instead of using n here, make n_a and n_b variables that are always 0 if the corresponding batch size is 1: https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L33

5-9) Do the same for the backward version of the function: https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L437 https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L129

Test that it works here: https://github.com/harvardnlp/genbmm/blob/master/genbmm/test_cuda.py#L12
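
I'm not sure of the final details, but once broadcasting along the batch dimension is in, the test could compare against a plain PyTorch reference along these lines (logbmm_broadcast_reference is just a hypothetical helper for this sketch, not something in genbmm, and it assumes a GPU):

import torch
import genbmm

def logbmm_broadcast_reference(a, b):
    # Pure-PyTorch reference for the intended semantics: a is (Ba, N, M) and
    # b is (Bb, M, K); the batch dimensions broadcast when one of them is 1.
    return (a.unsqueeze(-1) + b.unsqueeze(-3)).logsumexp(-2)

NT = 10
children = torch.rand(32, NT * NT, 1).cuda()   # (batch * spans * splits, NT*NT, 1)
grammar = torch.rand(1, NT, NT * NT).cuda()    # A->BC kept with batch size 1

expected = logbmm_broadcast_reference(grammar, children)
# After the steps above, this call should match `expected` without ever
# materializing the expanded grammar:
# result = genbmm.logbmm(grammar, children)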

sustcsonglin commented 4 years ago

I found a slightly better way, which reduces the memory from O(batch, n-w, w, A, B, C) to O(batch, n-w, A, B, C).

Instead of combining B and C first, we can combine the grammar rule A->BC with C first: (batch, AB, C) + (batch, C, (n-w) w) -> (batch, AB, (n-w) w) -> (batch (n-w), AB, w) + (batch (n-w), w, C) -> (batch (n-w), AB, C) -> (batch, n-w, A, BC) -> (batch, n-w, A).

But it still suffers from the O(NT^3) factor.

I think the previous idea is more attractive. If logbmm could take three arguments, grammars (batch, A, B, C), left (batch, n-w, w, B), and right (batch, n-w, w, C), and a kernel were designed to compute final[:, :, :, k] = logsumexp over i, j of left[:, :, :, i] + right[:, :, :, j] + grammars[:, k, i, j], that would be great: we would only need O(batch, n-w, w, A) memory in this situation.
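
As a reference for the semantics only, a naive pure-PyTorch sketch of that three-argument operator could look like this (the name is illustrative; it materializes the full (batch, n-w, w, A, B, C) tensor, which is exactly what a fused kernel would avoid):

import torch

def fused_rule_reduction_reference(grammars, left, right):
    # grammars: (batch, A, B, C); left: (batch, n-w, w, B); right: (batch, n-w, w, C)
    scores = (grammars[:, None, None]             # (batch, 1,   1, A, B, C)
              + left[:, :, :, None, :, None]      # (batch, n-w, w, 1, B, 1)
              + right[:, :, :, None, None, :])    # (batch, n-w, w, 1, 1, C)
    # logsumexp over the two child symbols i (=B) and j (=C)
    return scores.logsumexp(dim=(-2, -1))         # (batch, n-w, w, A)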

I found a similar library, "KeOps", which does lazy reductions and supports logsumexp, but I have not tried it yet.

srush commented 4 years ago

Cool, yeah I played with keops a bit but it didn't perform as well as I would have liked (see https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/keops.py). But I think maybe that was because I was only trying to do binary reductions. If you think the key here is a triple reduction, you should definitely try it out.

srush commented 4 years ago

Btw, does this same issue appear for dependency parsing? It would be nice to have a kernel that wasn't so CFG specific.

sustcsonglin commented 4 years ago

No, it is not an issue for dependency parsing, since dependency parsing does not have "non-terminals". Dependency parsing can be regarded as a lexicalized CFG whose non-terminal set is null (or, in the dependency model with valence (DMV), just the valence number). Do you suggest that I modify your binary-reduction kernel into a triplet one? I have no CUDA programming experience; is it very difficult?

srush commented 4 years ago

Oh no, you should definitely not try to do triplets in CUDA; that would be really messy.

I think the right way to do this is to remove this expansion function https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/fast_semirings.py#L19 and instead implement binary broadcasting in cuda correctly. That way you just never create the bigger (A->BC) tensor.

However I think keops is worth exploring as well. You might consider just copying my code and trying out keops directly without any of the semiring stuff.

sustcsonglin commented 4 years ago

Thank you, I'll give it a try.

sustcsonglin commented 4 years ago

By the way, I found that PyTorch's autograd uses a large amount of GPU memory to compute gradients. If I use a linear scan to explicitly implement the outside algorithm and use the inside-outside algorithm to compute the gradients, it saves 10x GPU memory and is 1.5x faster, but it is annoying to implement the outside algorithm manually for each algorithm. Do you have any ideas for combining the advantages of both?

srush commented 4 years ago

Which algorithm are you talking about in particular? Also, what do you mean by linear scan here? I don't use a linear scan for any of the tree approaches.

I started by implementing backward manually; it didn't actually make things much faster, and it was difficult to use different semirings.

sustcsonglin commented 4 years ago

Eisner, zeroth-order CKY, PCFGs, and so on.

By linear scan I mean an O(n) implementation here (considering all spans with the same width at the same time). I am trying to combine the genbmm.logbmm function with the outside algorithm now; it seems to reduce memory by around 30x compared to your CKY_CRF implementation, which does not use logbmm and uses autograd to compute the gradients.

srush commented 4 years ago

Hmm, would be curious to know how CKY_CRF with logbmm compares to manual backward. Not sure where it is storing so much extra memory.

sustcsonglin commented 4 years ago

In cky_crf with batch=10, length=50, NT=25, T=25: logbmm + autograd saves 10x memory, and logbmm + inside-outside saves 80x memory. I have to re-compute many terms to save the space, so it is a trade-off between speed and memory, but it is not too slow, only 1.5x slower than the original implementation. In zeroth-order CKY there is no need to re-compute anything, so I can be 1.5x faster if I use the inside-outside algorithm.

srush commented 4 years ago

One really nice trick to save memory via recomputation (without more code) is to use checkpointing. It basically just automatically reruns the forward pass for you.

Here is an example of that: https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/checkpoint.py
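
A minimal sketch of the idea with plain torch.utils.checkpoint (the wrapped function and the shapes below are just illustrative):

import torch
from torch.utils.checkpoint import checkpoint

def log_matmul(a, b):
    # stand-in for one expensive inside step; the broadcast sum below is the
    # large activation that autograd would normally keep around for backward
    return (a.unsqueeze(-1) + b.unsqueeze(-3)).logsumexp(-2)

a = torch.randn(16, 50, 625, requires_grad=True)
b = torch.randn(16, 625, 25, requires_grad=True)

# checkpoint() frees log_matmul's intermediates after the forward pass and
# recomputes them during backward, trading compute for memory.
out = checkpoint(log_matmul, a, b)
out.sum().backward()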

srush commented 4 years ago

Interesting, I did try this originally, but it is not as precise. I got some underflow errors.

On Fri, Aug 14, 2020 at 8:06 AM sustcsonglin notifications@github.com wrote:

I found that we can make use of highly optimized matrix multiplication in exp space (instead of log space).

I can simply write an alternative to logbmm:

def logsumexp(a, b):
    m = torch.max(torch.max(a), torch.max(b))
    a1 = (a - m).exp()
    b1 = (b - m).exp()
    c = torch.matmul(a1, b1)
    return c.log().add_(2 * m)

It uses less memory than logbmm, and we can make use of the highly optimized torch.matmul operation here.

Similarly, we can make use of the opt_einsum library to handle the triplet situation:

import torch
from opt_einsum import contract

def logsumexp_V2(a, b, c):
    # shape of a: (b, n, w, Y)   left span
    # shape of b: (b, n, w, Z)   right span
    # shape of c: (b, X, Y, Z)   grammar rules
    ma = torch.max(a)
    mb = torch.max(b)
    mc = torch.max(c)
    m = torch.max(torch.max(ma, mb), mc)
    a1 = (a - m).exp()
    b1 = (b - m).exp()
    c1 = (c - m).exp()
    # contract over the split point w and the children Y, Z in one shot
    res = contract("bnwy, bnwz, bxyz -> bnx", a1, b1, c1, backend='torch')
    return res.log().add_(3 * m)

It saves a large amount of memory, and my issue has been solved.


sustcsonglin commented 4 years ago

Me too. I am going to do it as follows: https://stackoverflow.com/a/52916131

srush commented 4 years ago

Yeah, unfortunately I tried a bunch of these tricks and ended up realizing it is hard to beat the CUDA version. But I love that you are looking into this! It would be fantastic if there were a good trick here.

sustcsonglin commented 4 years ago

I found this piece of code in https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/hmm.py and used it in place of logbmm. It saves 10% memory, but it is around 2x slower than the CUDA version:

import torch

class _SafeLog(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.log()

    @staticmethod
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad / x.clamp(min=torch.finfo(x.dtype).eps)

def safe_log(x):
    """
    Like :func:`torch.log` but avoids infinite gradients at log(0)
    by clamping them to at most ``1 / finfo.eps``.
    """
    return _SafeLog.apply(x)

def _logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.
    """
    finfo = torch.finfo(x.dtype)  # avoid nan due to -inf - -inf
    x_shift = x.max(-1, keepdim=True).values.clamp(min=finfo.min)
    y_shift = y.max(-2, keepdim=True).values.clamp(min=finfo.min)
    xy = (torch.matmul((x - x_shift).exp(), (y - y_shift).exp())).log()
    return xy + x_shift + y_shift
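
A quick way to sanity-check the function above against the CUDA kernel (this assumes the genbmm.logbmm(a, b) API and a GPU):

a = torch.randn(8, 16, 32).cuda()
b = torch.randn(8, 32, 24).cuda()

out_torch = _logmatmulexp(a, b)    # pure PyTorch, log-space batched matmul
out_cuda = genbmm.logbmm(a, b)     # genbmm CUDA kernel
print(torch.allclose(out_torch, out_cuda, atol=1e-4))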

srush commented 4 years ago

Isn't this code the same as from the stackoverflow above?
