Thank you for your quick response.
I think it is now feasible to train on sentences of arbitrary length, since the current implementation is optimized. Could you please briefly describe what optimizations have been applied in your current implementation? I noticed you re-implemented the inside function with Triton. Is there anything else important?
The main speedup comes from the log-einsum-exp trick. You can also refer to my pure PyTorch implementation to get the idea of this trick. In my Triton implementation, I do kernel fusion to reduce IO cost, in the spirit of FlashAttention.
Sorry, I don't get it. What do you mean by the log-einsum-exp trick? Is that the log-sum-exp trick? To my knowledge, that is for avoiding overflow, not for acceleration. Could you please describe this in more detail?
The log-einsum-exp trick was introduced in Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits: https://arxiv.org/pdf/2004.06231.pdf
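For intuition, here is a minimal PyTorch sketch of the trick (the function name and shapes are my own, not from the repo). The key point is that after factoring out the per-row and per-column maxima for numerical stability, the remaining log-sum-exp reduction becomes an ordinary matmul, which runs on the fast cuBLAS path instead of a memory-bound broadcast-then-logsumexp:

```python
import torch

def log_matmul_exp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Stably compute C[i, j] = log sum_k exp(log_A[i, k] + log_B[k, j]).

    Subtracting the row/column maxima keeps exp() in range, and the
    leftover sum-product is a single dense matmul.
    """
    a_max = log_A.max(dim=-1, keepdim=True).values   # (..., m, 1)
    b_max = log_B.max(dim=-2, keepdim=True).values   # (..., 1, n)
    C = (log_A - a_max).exp() @ (log_B - b_max).exp()
    return C.log() + a_max + b_max
```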
Thank you very much~
Hi, I have one more question. The log-einsum-exp trick works fascinatingly well, but is there any way to extend it to the Max semiring for inference?
PyTorch does not have an optimized CUDA kernel for generalized matrix multiplication other than matmul, so a customized kernel implementation is needed. Some useful references: https://github.com/harvardnlp/genbmm (caveat: a little bit slow), https://www.kernel-operations.io/keops/index.html (I have not used it before, but you can have a try), and https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html (you can simply replace tl.dot with tl.max).
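To make the target operation concrete, here is a naive PyTorch sketch of the Max-semiring (max-plus) analogue of matmul. It is correct but slow: the broadcast materializes a full (m, k, n) intermediate, which is exactly the IO cost a fused custom kernel avoids:

```python
import torch

def maxplus_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Naive Max-semiring matmul: C[i, j] = max_k (A[i, k] + B[k, j]).

    The broadcasted sum has shape (..., m, k, n), so for real inputs
    a fused kernel (genbmm, KeOps, or Triton) is preferable.
    """
    return (A.unsqueeze(-1) + B.unsqueeze(-3)).max(dim=-2).values
```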
Generally speaking, if you are only interested in decoding, you can use autograd to estimate the marginals and run a simple CKY over them, namely the MBR decoding described in the paper. If you want a labeled tree, you can estimate the labeled span marginals and run a simple CKY. If you really want the argmax tree, you need to implement the kernel yourself.
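As a rough illustration of the autograd route (the `inside_fn` name and its interface are hypothetical, not the repo's API): the gradient of the log-partition function with respect to the span scores gives the posterior span marginals, which can then be fed to a vanilla CKY for MBR decoding:

```python
import torch

def span_marginals(span_scores: torch.Tensor, inside_fn) -> torch.Tensor:
    """Span marginals via autograd.

    `inside_fn` is assumed to map span scores to log Z (the inside
    algorithm); then d logZ / d span_scores[i, j] is the posterior
    probability that span (i, j) appears in the tree.
    """
    span_scores = span_scores.detach().requires_grad_(True)
    log_Z = inside_fn(span_scores)
    (marginals,) = torch.autograd.grad(log_Z.sum(), span_scores)
    return marginals
```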
A little bit different from your answer: for a given distribution `dist`, e.g., a labeled CRF or labeled CKY, does `dist.marginal.argmax(dim=-1)` always equal `dist.argmax`? I guess the answer is no?
No. Taking the argmax of the marginals independently (MBR-style) and taking the argmax over whole trees (MAP/Viterbi) optimize different objectives, so they can disagree.
Thanks~
Hi, thank you for sharing the source code. I have several questions about the data-processing details.
In fast_tnpcfg_r1000_nt9000_t4500_curriculum0.yaml, the value at the linked line turns out to be 400. Is this on purpose, or just a typo? https://github.com/sustcsonglin/TN-PCFG/blob/7047645f874dcf872ed550d6bcd8d5d2b113d50c/config/fast_tnpcfg_r1000_nt9000_t4500_curriculum0.yaml#L32
About dropping sentences with length longer than 40, could you please explain the reason for doing this? Is it to avoid out-of-memory, or do long sentences just hurt performance?
But you do use all sentences in the dev and test splits without dropping any. Is this correct?