harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

Conditional and joint span probabilities in TreeCRF #107

Open rubencart opened 3 years ago

rubencart commented 3 years ago

I want to use the TreeCRF class to learn latent tree distributions over constituency trees for sentences. I noticed you can easily obtain the text span marginals with .marginals. However, I am interested in computing other probabilities under the tree distribution, like the conditional probability that one span occurs in the tree given that another one occurs, or the joint probability of two spans. Is there an easy way to compute these from the marginals, or using other torch-struct functionality?

A 'dirty' trick for the conditional probability could be to compute the marginals again with the log-potential of the span you want to condition on set to a very high value; the new marginals would then effectively be conditional probabilities. But that requires running the parsing algorithm once per condition, which I would ideally like to avoid.
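For concreteness, a minimal sketch of that trick, assuming TreeCRF takes span log-potentials of shape (batch, N, N, NT); the indices and shapes here are illustrative, so check the convention in your torch-struct version:

```python
import torch
from torch_struct import TreeCRF

# Illustrative setup: batch of 1 sentence, N tokens, NT nonterminal labels.
B, N, NT = 1, 6, 1
log_potentials = torch.randn(B, N, N, NT)

dist = TreeCRF(log_potentials)
marginals = dist.marginals  # p(span in tree) for every (start, end, label)

# Condition on a span by pushing its log-potential toward +inf: the
# distribution concentrates on trees containing it, so the recomputed
# marginals approximate p(span in tree | conditioning span in tree).
i, j, a = 0, 2, 0  # hypothetical span (start=0, end=2, label=0) to condition on
clamped = log_potentials.clone()
clamped[:, i, j, a] += 1e4  # large but finite keeps logsumexp stable
cond_marginals = TreeCRF(clamped).marginals
```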

srush commented 3 years ago

What a great question...

If there are two specific spans that you need, then your "dirty trick" is the right way to do it.

If you want to do it efficiently for any pair of spans, then there are some fun auto-diff tricks you can use. If I remember correctly, the Hessian of the log-partition (with respect to the log-potentials) gives you the covariance of every pair of spans, from which you can recover the joint of any pair by adding back the product of their marginals. I don't think this is currently implemented in the library, but it wouldn't be that hard to add.

Maybe look at https://pytorch.org/docs/stable/generated/torch.autograd.functional.hessian.html
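A sketch of what that could look like, assuming a (batch, N, N, NT) potential shape, that TreeCRF exposes .partition, and that the inside algorithm is twice differentiable (it is built from logsumexp, so second derivatives should go through):

```python
import torch
from torch.autograd.functional import hessian
from torch_struct import TreeCRF

# Illustrative shapes; NT = 1 keeps the flattened span index small.
B, N, NT = 1, 5, 1
log_potentials = torch.randn(B, N, N, NT)

def log_Z(lp_flat):
    # Work on a flat vector so the Hessian comes out as one 2-D matrix
    # indexed by pairs of spans.
    return TreeCRF(lp_flat.view(B, N, N, NT)).partition.sum()

lp_flat = log_potentials.reshape(-1).requires_grad_(True)

# First derivatives of log Z are the span marginals p(j).
marginals = torch.autograd.grad(log_Z(lp_flat), lp_flat)[0]

# Second derivatives of log Z are covariances: H[j, k] = p(j, k) - p(j) p(k),
# so adding back the outer product of marginals recovers the pairwise joint.
H = hessian(log_Z, lp_flat)
joint = H + torch.outer(marginals, marginals)
```

Conditionals for any pair then follow from p(j | k) = p(j, k) / p(k). Note the Hessian is quadratic in the number of spans, so this only makes sense for moderate sentence lengths.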

rubencart commented 3 years ago

Thank you :-). I am more interested in an efficient way for any pair of spans. The Hessian trick sounds interesting; could you point me to a reference that explains this relation, and/or references relating other auto-diff tricks to other probabilities (like the conditionals)?

srush commented 3 years ago

I think this is a nice reference for Bayes nets: https://dl.acm.org/doi/pdf/10.1145/765568.765570


Alternatively, you can think of the CRF as an exponential family, so the log-partition generates the moments (cumulants) of the part indicators:

https://www.cs.cmu.edu/~epxing/Class/10708-14/scribe_notes/scribe_note_lecture6.pdf

I can't find a nice reference that explains the Hessian directly, but you can derive it by differentiating the log-partition \log Z(l) = \log \sum_y \exp(\sum_{i \in y} l_i) twice with respect to l_j and l_k. You end up with

\frac{\partial^2 \log Z}{\partial l_j \, \partial l_k} = p(j, k) - p(j)\, p(k),

where p(j, k) is the term with the sum over all structures containing both parts j and k in the numerator and the partition function in the denominator; adding back the product of the marginals recovers the joint.

If you are feeling brave, it is also derived in this paper:

https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00391/102843/Efficient-Computation-of-Expectations-under

rubencart commented 3 years ago

That's super interesting, thank you for your help!

I am going to try to use this for my project, but would you like me to attempt a PR to add it to this library as well? If so, do you have any pointers on where to add it?

Some last questions :-): The derivations in the last paper are for distributions over spanning trees of graphs, defined by edge weights, but I think they remain valid for distributions over constituency trees if you just replace edge weights with span weights, right? (Not the efficient computations using the Matrix-Tree Theorem, of course, but the relations between partial derivatives and expectations.)

Also, am I right to deduce that third-order partial derivatives of the partition function then give joint probabilities of three edges (assuming the second-order derivatives are themselves differentiable)? Unless I'm making a mistake, this can be proved the same way as the relation between the first- and second-order derivatives and the marginals and pairwise joints, respectively.
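For reference, the pattern being appealed to (editor's notation: y ranges over structures, l_i is the log-potential of part i; since the score is linear in l, each differentiation of Z inserts one part indicator):

Z(l) = \sum_y \exp\Big(\sum_{i \in y} l_i\Big), \qquad \frac{1}{Z}\frac{\partial Z}{\partial l_j} = p(j), \qquad \frac{1}{Z}\frac{\partial^2 Z}{\partial l_j \, \partial l_k} = p(j, k), \qquad \frac{1}{Z}\frac{\partial^3 Z}{\partial l_j \, \partial l_k \, \partial l_m} = p(j, k, m)

So the pattern does continue to any order, with derivatives of Z (normalized by Z) giving joints, and derivatives of \log Z giving the corresponding cumulants.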

srush commented 3 years ago

Yes, would love a PR. I think you can do this in the StructDistribution class https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/distributions.py. Should in theory work for all the distributions in the library.

The main thing I would suggest is that everything in torch-struct needs to be general (not tree- or sequence-specific) and tested. The testing is done by enumerating all trees up to a certain size and checking that the pairwise marginals match; a brute-force sketch of that style of test follows.
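For instance, a test in that spirit might look like the following (all_trees is a hypothetical helper written for this sketch; it enumerates binary bracketings directly over unlabeled spans rather than using any library internals):

```python
import torch
from torch.autograd.functional import hessian

def all_trees(i, j):
    # Hypothetical brute-force helper: all span sets forming a binary
    # tree over tokens i..j (only usable for tiny N).
    if i == j:
        return [frozenset([(i, j)])]
    trees = []
    for k in range(i, j):
        for left in all_trees(i, k):
            for right in all_trees(k + 1, j):
                trees.append(left | right | frozenset([(i, j)]))
    return trees

N = 4
trees = all_trees(0, N - 1)
lp = torch.randn(N, N, dtype=torch.double)  # one potential per span (NT = 1)

def score(t, x):
    return sum(x[i, j] for (i, j) in t)

def log_Z(x):
    return torch.logsumexp(torch.stack([score(t, x) for t in trees]), 0)

# Ground-truth marginals and pairwise joints by enumeration.
probs = torch.softmax(torch.stack([score(t, lp) for t in trees]), 0)
marg = torch.zeros(N, N, dtype=torch.double)
joint = torch.zeros(N, N, N, N, dtype=torch.double)
for p, t in zip(probs, trees):
    for (i, j) in t:
        marg[i, j] += p
        for (k, l) in t:
            joint[i, j, k, l] += p

# Autodiff version: Hessian of log Z plus the outer product of marginals.
flat = lp.reshape(-1)
H = hessian(lambda x: log_Z(x.view(N, N)), flat)
auto_joint = H + torch.outer(marg.reshape(-1), marg.reshape(-1))
assert torch.allclose(joint.reshape(N * N, N * N), auto_joint, atol=1e-6)
```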

However, do not feel compelled. Would be happy just to hear if this works for your use case.