harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License
1.11k stars, 93 forks

DependencyCRF marginals possible error #64

Closed by kmkurn 4 years ago

kmkurn commented 4 years ago

Hi, while working on #63, I noticed that DependencyCRF marginals may have numerical errors:

>>> import torch
>>> from torch_struct import DependencyCRF
>>> crf = DependencyCRF(torch.zeros(1,2,2))
>>> print(crf.partition.exp().item())
3.0
>>> crf.marginals.exp()
tensor([[[1.9477, 1.3956],
         [1.3956, 1.9477]]], grad_fn=<ExpBackward>)

crf.partition is correct: there are 3 trees. Since all edges have weight 1 (log-potential 0), I'd expect the marginals to be (very close to) 2 on the diagonal and 1 off the diagonal. But they're not. Is this an error, or am I misunderstanding something?
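The expected counts of 2 and 1 can be checked by brute force. The sketch below is plain Python and does not call pytorch-struct; the head-array encoding (heads[c] == c marks a root attachment, i.e. a diagonal cell, and heads[c] == h marks the arc h -> c) and the helper names brute_force_counts and _is_tree are hypothetical, chosen to mirror the 2 x 2 layout in the issue. It assumes several words may attach to the root, which is what a partition value of 3 implies.

```python
from itertools import product


def _is_tree(heads):
    """True if every word reaches a root attachment without a cycle.

    Hypothetical encoding: heads[c] == c means word c hangs off the root.
    """
    for start in range(len(heads)):
        seen = set()
        node = start
        while heads[node] != node:
            if node in seen:  # revisited a node: this head array has a cycle
                return False
            seen.add(node)
            node = heads[node]
    return True


def brute_force_counts(n):
    """Count the trees over n words and how many trees contain each cell.

    counts[h][c] is the number of trees using arc h -> c (h != c) or,
    on the diagonal (h == c), the number of trees attaching c to the root.
    Multiple root children are allowed (an assumption, see lead-in).
    """
    counts = [[0] * n for _ in range(n)]
    num_trees = 0
    for heads in product(range(n), repeat=n):
        if not _is_tree(heads):
            continue
        num_trees += 1
        for child, head in enumerate(heads):
            counts[head][child] += 1
    return num_trees, counts


num_trees, counts = brute_force_counts(2)
print(num_trees)  # 3
print(counts)     # [[2, 1], [1, 2]]
```

So the counts really are 2 on the diagonal and 1 off it. Dividing by the 3 trees gives 2/3 and 1/3, and e^(2/3) ≈ 1.9477 and e^(1/3) ≈ 1.3956 are exactly the values in the output above, which suggests the marginals are probabilities that were then exponentiated.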

srush commented 4 years ago

I don't think there is a bug.

import torch
import torch_struct

crf = torch_struct.DependencyCRF(torch.zeros(1, 2, 2))
print(crf.partition.exp().item())  # 3.0
print(crf.marginals)
# tensor([[[0.6667, 0.3333],
#          [0.3333, 0.6667]]], grad_fn=<SqueezeBackward1>)

Marginals are just probabilities (not in log space), i.e. p(arc | x).
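The claim that marginals are p(arc | x) can be checked numerically, together with the standard identity that p(arc | x) equals the derivative of log Z with respect to that arc's log-potential. This plain-Python sketch is not pytorch-struct's implementation: it enumerates the three trees from the issue by hand (the TREES encoding, with diagonal cells as root attachments, is an assumption) and compares direct enumeration against a central finite difference of log Z.

```python
import math

# The three trees from the issue as sets of (head, child) cells in the
# 2 x 2 layout; a diagonal cell (i, i) means word i attaches to the root.
TREES = [
    {(0, 0), (0, 1)},  # root -> 0, 0 -> 1
    {(0, 0), (1, 1)},  # root -> 0, root -> 1
    {(1, 1), (1, 0)},  # root -> 1, 1 -> 0
]


def tree_score(theta, tree):
    """Sum of arc log-potentials for one tree."""
    return sum(theta[h][c] for h, c in tree)


def log_partition(theta):
    """log Z: log of the summed exp-scores of all trees."""
    return math.log(sum(math.exp(tree_score(theta, t)) for t in TREES))


def marginals_by_enumeration(theta):
    """p(arc | x): total probability of the trees that contain each cell."""
    log_z = log_partition(theta)
    probs = [math.exp(tree_score(theta, t) - log_z) for t in TREES]
    return [[sum(p for p, t in zip(probs, TREES) if (h, c) in t)
             for c in range(2)]
            for h in range(2)]


def marginals_by_gradient(theta, eps=1e-6):
    """Central finite difference of log Z with respect to each log-potential."""
    out = [[0.0, 0.0], [0.0, 0.0]]
    for h in range(2):
        for c in range(2):
            plus = [row[:] for row in theta]
            minus = [row[:] for row in theta]
            plus[h][c] += eps
            minus[h][c] -= eps
            out[h][c] = (log_partition(plus) - log_partition(minus)) / (2 * eps)
    return out


theta = [[0.0, 0.0], [0.0, 0.0]]  # torch.zeros(1, 2, 2) without the batch dim
print(math.exp(log_partition(theta)))   # about 3.0
print(marginals_by_enumeration(theta))  # about [[0.6667, 0.3333], [0.3333, 0.6667]]
print(marginals_by_gradient(theta))     # same values, up to finite-difference error
print(math.exp(2 / 3), math.exp(1 / 3)) # about 1.9477 and 1.3956, as in the issue
```

Both routes agree, and exponentiating the probabilities reproduces the "surprising" 1.9477 / 1.3956 from the original report.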

I should have named it log_partition rather than partition, to be accurate. I will update that.
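A likely reason the value is kept as log Z in the first place (an assumption on our part; the thread does not say so) is that Z itself overflows a float64 once tree scores pass roughly 709. The sketch below illustrates this in plain Python without calling the library; the logsumexp helper and the log-potential of 400 per arc are hypothetical.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs)): shift by the max first."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


# With all-zero log-potentials (as in the issue) each of the 3 trees scores 0,
# so log Z = log(3) and exp(log Z) is the 3.0 printed above.
print(logsumexp([0.0, 0.0, 0.0]))  # log(3), about 1.0986

# Each tree has two arcs; with a hypothetical log-potential of 400.0 per arc,
# every tree scores 800.0.
tree_scores = [800.0, 800.0, 800.0]
log_z = logsumexp(tree_scores)  # 800 + log(3), about 801.0986

try:
    z = sum(math.exp(s) for s in tree_scores)  # math.exp(800) overflows float64
except OverflowError:
    z = float("inf")

print(log_z)
print(z)  # inf
```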

kmkurn commented 4 years ago

I see. Thanks for the clarification!