sustcsonglin / TN-PCFG

Source code of NAACL2021 "PCFGs Can Do Better: Inducing Probabilistic Context-Free Grammars with Many Symbols" and ACL2021 main conference "Neural Bilexicalized PCFG Induction"

Questions about experiment details #3

Closed speedcell4 closed 1 year ago

speedcell4 commented 1 year ago

Hi, thank you for sharing the source code. I got several questions about the data processing details.

  1. About the training sentence length. I found that you set the maximum length of training sentences to 40 in most of your settings, but only in fast_tnpcfg_r1000_nt9000_t4500_curriculum0.yaml it is 400. Is this on purpose, or just a typo?

https://github.com/sustcsonglin/TN-PCFG/blob/7047645f874dcf872ed550d6bcd8d5d2b113d50c/config/fast_tnpcfg_r1000_nt9000_t4500_curriculum0.yaml#L32

  2. About dropping sentences longer than 40: could you please explain the reason for doing this? Is it to avoid out-of-memory errors, or do long sentences hurt performance?

  3. But you do use all sentences in the dev and test splits without dropping, is this correct?

sustcsonglin commented 1 year ago
  1. Maybe it is a typo :)
  2. We just follow the previous practice of [Kim et al. 2019]. They train on sentences of length up to 40 for computational reasons, because their implementation is not optimized. I do not believe long sentences hurt performance, and it is now possible to train on sentences of arbitrary length, since the current implementation is highly optimized with Triton.
  3. Yes, we evaluate on all validation/test sentences.
speedcell4 commented 1 year ago

Thank you for your quick response.

I see that it is now possible to train on sentences of arbitrary length, since the current implementation is optimized.

Could you please briefly describe what optimizations have been applied to your current implementation? I noticed you re-implemented the inside function with Triton. Is there anything else important?

sustcsonglin commented 1 year ago

The main speedup comes from the log-einsum-exp trick; you can also refer to my pure PyTorch implementation to get the idea of this trick. In my Triton implementation, I do kernel fusion to reduce IO cost, in the spirit of FlashAttention.

speedcell4 commented 1 year ago

Sorry, I don't get it. What do you mean by the log-einsum-exp trick? Is that the log-sum-exp trick? To my knowledge, that is for avoiding overflow, not for acceleration. Could you please give more details on this?

sustcsonglin commented 1 year ago

The log-einsum-exp trick is introduced in Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits (https://arxiv.org/pdf/2004.06231.pdf).
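Roughly, the idea is that a log-space matrix product can be computed with an ordinary fast matmul by factoring out per-row/column maxima, instead of broadcasting a huge tensor and calling logsumexp. A minimal PyTorch sketch of the idea (simplified, not the exact implementation in this repo):

```python
import torch

def log_matmul(log_A, log_B):
    # Computes log(exp(log_A) @ exp(log_B)) for log_A: (..., n, k), log_B: (..., k, m).
    # Instead of building an (..., n, k, m) broadcasted tensor and reducing it with
    # logsumexp, we shift by per-row / per-column maxima so that the inner reduction
    # becomes an ordinary matmul, which runs as a single fast cuBLAS/Triton kernel.
    max_A = log_A.max(dim=-1, keepdim=True).values   # (..., n, 1)
    max_B = log_B.max(dim=-2, keepdim=True).values   # (..., 1, m)
    C = (log_A - max_A).exp() @ (log_B - max_B).exp()
    return C.log() + max_A + max_B                   # add the maxima back in log space
    # (in practice, clamp C away from zero to guard against log(0))
```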

speedcell4 commented 1 year ago

Thank you very much~

speedcell4 commented 1 year ago

Hi, I got one more question.

The log-einsum-exp trick works remarkably well, but is there any way to extend it to the max semiring for inference?

sustcsonglin commented 1 year ago

PyTorch does not have optimized CUDA kernels for generalized matrix multiplication other than standard matmul, so a customized kernel implementation is needed. Some useful references: https://github.com/harvardnlp/genbmm (caveat: a little bit slow), https://www.kernel-operations.io/keops/index.html (I have not used it myself, but you can give it a try), and https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html (replace the tl.dot accumulation with an elementwise add followed by a tl.max reduction).
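If you only want to prototype the max semiring without writing a kernel, a naive broadcasting version is enough (it materialises the full intermediate tensor, which is exactly what the fused kernels above avoid):

```python
import torch

def maxplus_bmm(A, B):
    # Max-plus "matmul": C[b, i, j] = max_k (A[b, i, k] + B[b, k, j]).
    # For prototyping only: the broadcasted sum below materialises a
    # (batch, i, k, j) tensor, which a fused genbmm/Triton kernel would
    # never build explicitly.
    scores = A.unsqueeze(-1) + B.unsqueeze(-3)     # (batch, i, k, j)
    values, backptr = scores.max(dim=-2)           # reduce over k
    return values, backptr                         # backptr gives the argmax k (backpointers)
```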

Generally speaking, if you are only interested in decoding, you can use autograd to estimate the span marginals and run a simple CKY over those marginals, namely the MBR decoding described in the paper. If you want a labelled tree, you can estimate the labelled span marginals and do a simple CKY over them. If you really want the argmax tree, you need to implement the kernel yourself.
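A sketch of the MBR route (assuming an unlabelled, fencepost-indexed setup where a tree's score is the sum of its span scores): differentiate log Z to get span marginals, then run a plain CKY that maximises the total marginal mass.

```python
import torch

def span_marginals(span_scores, inside_fn):
    # inside_fn: any differentiable routine mapping span scores to log Z
    # (e.g. the inside algorithm). When a tree's score is the sum of its
    # span scores, d logZ / d score[i, j] is the marginal of span (i, j).
    span_scores = span_scores.detach().requires_grad_(True)
    logZ = inside_fn(span_scores)
    marginals, = torch.autograd.grad(logZ, span_scores)
    return marginals

def mbr_cky(m):
    # Unlabelled CKY over span marginals m[i, j], fencepost indices 0..n.
    # Returns the best total marginal mass; backpointers omitted for brevity.
    n = m.size(0) - 1
    best = [[0.0] * (n + 1) for _ in range(n + 1)]
    for i in range(n):                               # width-1 spans
        best[i][i + 1] = m[i, i + 1].item()
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            j = i + width
            best[i][j] = m[i, j].item() + max(
                best[i][k] + best[k][j] for k in range(i + 1, j))
    return best[0][n]
```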

speedcell4 commented 1 year ago

A little bit different from your answer: for a given distribution dist, e.g., a labelled CRF or labelled CKY, does dist.marginal.argmax(dim=-1) always equal dist.argmax? I guess the answer is no?

sustcsonglin commented 1 year ago

> A little bit different from your answer: for a given distribution dist, e.g., a labelled CRF or labelled CKY, does dist.marginal.argmax(dim=-1) always equal dist.argmax? I guess the answer is no?

No
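For intuition, a toy joint distribution over two binary variables already shows the gap:

```python
import torch

# Toy joint over (y1, y2): P(0,0)=0.4, P(0,1)=0.0, P(1,0)=0.3, P(1,1)=0.3
joint = torch.tensor([[0.4, 0.0],
                      [0.3, 0.3]])

map_y = divmod(joint.argmax().item(), 2)          # joint argmax -> (0, 0), prob 0.4
marg_y = (joint.sum(1).argmax().item(),           # P(y1=1) = 0.6 -> y1 = 1
          joint.sum(0).argmax().item())           # P(y2=0) = 0.7 -> y2 = 0
print(map_y, marg_y)                              # (0, 0) vs (1, 0): they differ
```

The per-variable marginal argmax (1, 0) has joint probability 0.3, while the true MAP assignment (0, 0) has 0.4; the same mismatch carries over to labelled trees.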

speedcell4 commented 1 year ago

Thanks~