harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

Mini-batch setting with Semi Markov CRF #110

Closed: urchade closed this issue 3 years ago

urchade commented 3 years ago

I encounter training instability when using a batch size > 1 with the semi-Markov CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking. The model trains well with a batch size of 1.
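A large negative loss is itself a strong hint that the partition function is wrong. The negative log-likelihood of a labelled segmentation is log Z minus that segmentation's score, and since Z sums over every segmentation (including the gold one), the loss can never be negative when log Z is computed correctly. The plain-Python sketch below checks this invariant for a single sequence. It is not torch_struct code: the score layout score[t][d - 1][c][cp], the helper names, and the fixed start label 0 are all assumptions made for illustration.

```python
import math
import random


def logsumexp(values):
    """Stable log(sum(exp(v))) over an iterable; returns -inf if every value is -inf."""
    values = list(values)
    top = max(values)
    if top == -math.inf:
        return -math.inf
    return top + math.log(sum(math.exp(v - top) for v in values))


def forward(score, length, max_dur, n_labels, start_label=0):
    """alpha[j][c] = log-sum of scores of all segmentations of [0, j) ending in label c.

    Hypothetical layout (not torch_struct's): score[t][d - 1][c][cp] scores a
    segment starting at t with duration d and label c, preceded by label cp.
    """
    alpha = [[0.0 if c == start_label else -math.inf for c in range(n_labels)]]
    for j in range(1, length + 1):
        alpha.append([
            logsumexp(
                alpha[j - d][cp] + score[j - d][d - 1][c][cp]
                for d in range(1, min(max_dur, j) + 1)
                for cp in range(n_labels)
            )
            for c in range(n_labels)
        ])
    return alpha


def path_score(score, segments, start_label=0):
    """Score of one labelled segmentation given as (start, duration, label) triples."""
    total, prev = 0.0, start_label
    for start, dur, label in segments:
        total += score[start][dur - 1][label][prev]
        prev = label
    return total


def segment_nll(score, segments, length, max_dur, n_labels):
    """Negative log-likelihood log Z - score(gold); never negative if log Z is right."""
    log_z = logsumexp(forward(score, length, max_dur, n_labels)[length])
    return log_z - path_score(score, segments)


# Toy example: 6 steps, segment durations 1..3, 2 labels.
T, D, C = 6, 3, 2
rng = random.Random(0)
score = [[[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
          for _ in range(D)] for _ in range(T)]
gold = [(0, 2, 1), (2, 1, 0), (3, 3, 1)]  # covers positions 0..6 exactly
loss = segment_nll(score, gold, T, D, C)
assert loss >= 0.0, "a negative loss means log Z was computed incorrectly"
```

Running the same check on each item of a padded batch is a cheap way to detect a masking bug: an item whose loss comes out negative is one whose log Z was read or masked incorrectly.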

srush commented 3 years ago

Thanks. Could you provide an example? I will take a look.

urchade commented 3 years ago

The problem also occurs during inference:

import torch, torch_struct
import matplotlib.pyplot as plt

torch.manual_seed(1)

batch, N, C, K = 3, 10, 2, 6

def show_sm(chain):
    plt.imshow(chain.detach().sum(1).sum(-1).transpose(0, 1))

log_potentials = torch.randn(batch, N, K, C, C)

# Distributions with and without explicit lengths (the 0th batch element is never padded)
dist_1 = torch_struct.SemiMarkovCRF(log_potentials)
dist_2 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 1]))
dist_3 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 4]))

# argmax for the 0th index should be the same for every dist since there is no padding on this index
assert torch.allclose(dist_1.argmax[0], dist_2.argmax[0])
assert torch.allclose(dist_1.argmax[0], dist_3.argmax[0])
assert torch.allclose(dist_2.argmax[0], dist_3.argmax[0])
srush commented 3 years ago

Oh, thanks. This is a useful test (and it sounds like a bug).

@da03 we should fix this. Any chance you could take a first look?

da03 commented 3 years ago

@urchade Thanks for pointing this out! It's fixed in PR #114. The issue was caused by this line https://github.com/harvardnlp/pytorch-struct/blob/5328ec52263c008d209cfda171de85d51c4d988f/torch_struct/semimarkov.py#L67, which did not account for the different ending positions of sentences with different lengths.
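The fix described above amounts to reading each item's forward scores at its own end position instead of at the shared, padded end of the batch. The following is a minimal plain-Python illustration of that failure mode, not the torch_struct implementation. It assumes a hypothetical score layout score[t][d - 1][c][cp] (segment starting at t with duration d, label c, previous label cp) and a fixed start label 0. Lengths here count segment steps, so an unpadded item has length equal to the number of score rows, unlike the N+1 convention used in the example above.

```python
import math
import random


def logsumexp(values):
    """Stable log(sum(exp(v))) over an iterable; returns -inf if every value is -inf."""
    values = list(values)
    top = max(values)
    if top == -math.inf:
        return -math.inf
    return top + math.log(sum(math.exp(v - top) for v in values))


def forward(score, length, max_dur, n_labels, start_label=0):
    """alpha[j][c] = log-sum over segmentations of [0, j) whose last label is c."""
    alpha = [[0.0 if c == start_label else -math.inf for c in range(n_labels)]]
    for j in range(1, length + 1):
        alpha.append([
            logsumexp(
                alpha[j - d][cp] + score[j - d][d - 1][c][cp]
                for d in range(1, min(max_dur, j) + 1)
                for cp in range(n_labels)
            )
            for c in range(n_labels)
        ])
    return alpha


def batched_log_partition(scores, lengths, max_dur, n_labels):
    """log Z per item, read at each item's own end position."""
    out = []
    for score, length in zip(scores, lengths):
        alpha = forward(score, len(score), max_dur, n_labels)  # run over padded length
        out.append(logsumexp(alpha[length]))  # read at this item's end, not alpha[-1]
    return out


def buggy_batched_log_partition(scores, max_dur, n_labels):
    """Illustrative bug: every item is read at the shared padded end."""
    return [logsumexp(forward(s, len(s), max_dur, n_labels)[-1]) for s in scores]


T, D, C = 8, 3, 2
rng = random.Random(1)


def random_scores():
    return [[[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
             for _ in range(D)] for _ in range(T)]


scores = [random_scores() for _ in range(3)]
lengths = [T, 5, 1]  # item 0 is unpadded; items 1 and 2 are padded

fixed = batched_log_partition(scores, lengths, D, C)
buggy = buggy_batched_log_partition(scores, D, C)

for score, length, value in zip(scores, lengths, fixed):
    # Reference: run the recursion on the truncated sequence only.
    reference = logsumexp(forward(score[:length], length, D, C)[length])
    assert abs(value - reference) < 1e-12
```

Because alpha[j] only depends on segments that end at or before position j, whatever scores sit in the padded region cannot affect alpha[length]; only the readout position matters. In a batched tensor implementation, the analogous step would be a gather along the position dimension indexed by the lengths.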

I also added back another implementation of the log-partition computation, _dp_standard, which is more memory-efficient and can be used as follows:

import torch, torch_struct

torch.manual_seed(1)

batch, N, C, K = 3, 10, 2, 6

log_potentials = torch.randn(batch, N, K, C, C)

dist_1 = torch_struct.SemiMarkov()
dist_2 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 1]))

assert torch.allclose(dist_1._dp_standard(log_potentials, lengths=torch.LongTensor([N+1, 5, 1]))[0], dist_2.partition)
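The assert above validates one log-partition routine against another. On tiny inputs, the same kind of check can be made against exhaustive enumeration of every labelled segmentation, which is a handy oracle when testing any hand-written semi-Markov recursion. This plain-Python sketch assumes a hypothetical score layout score[t][d - 1][c][cp] and a fixed start label 0; the function names are illustrative and not part of torch_struct.

```python
import math
import random


def logsumexp(values):
    """Stable log(sum(exp(v))) over an iterable; returns -inf if every value is -inf."""
    values = list(values)
    top = max(values)
    if top == -math.inf:
        return -math.inf
    return top + math.log(sum(math.exp(v - top) for v in values))


def forward(score, length, max_dur, n_labels, start_label=0):
    """Semi-Markov forward recursion; alpha[j][c] covers [0, j) ending in label c."""
    alpha = [[0.0 if c == start_label else -math.inf for c in range(n_labels)]]
    for j in range(1, length + 1):
        alpha.append([
            logsumexp(
                alpha[j - d][cp] + score[j - d][d - 1][c][cp]
                for d in range(1, min(max_dur, j) + 1)
                for cp in range(n_labels)
            )
            for c in range(n_labels)
        ])
    return alpha


def path_score(score, segments, start_label=0):
    """Score of one labelled segmentation given as (start, duration, label) triples."""
    total, prev = 0.0, start_label
    for start, dur, label in segments:
        total += score[start][dur - 1][label][prev]
        prev = label
    return total


def enumerate_paths(length, max_dur, n_labels, start=0):
    """Yield every labelled segmentation of [start, length) as a list of triples."""
    if start == length:
        yield []
        return
    for d in range(1, min(max_dur, length - start) + 1):
        for c in range(n_labels):
            for rest in enumerate_paths(length, max_dur, n_labels, start + d):
                yield [(start, d, c)] + rest


def brute_log_partition(score, length, max_dur, n_labels):
    """log Z by summing exp(score) over every segmentation explicitly."""
    return logsumexp(path_score(score, p)
                     for p in enumerate_paths(length, max_dur, n_labels))


T, D, C = 5, 3, 2
rng = random.Random(2)
score = [[[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
          for _ in range(D)] for _ in range(T)]

dp = logsumexp(forward(score, T, D, C)[T])
brute = brute_log_partition(score, T, D, C)
assert abs(dp - brute) < 1e-9
```

Exhaustive enumeration grows exponentially with the sequence length, so it is only practical for very small T, D and C, which is exactly the regime where it is most useful as a test oracle.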
srush commented 3 years ago

Oh wow, impressive, @da03! This code is really complex.

Long term, let's make SemiMarkovParallel and SemiMarkovFlat their own classes and let the CRF pick which one to use.