kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Freezing certain transition scores while training #120

Open aberthel opened 3 months ago

aberthel commented 3 months ago

I want to build a CRF that represents a specific state graph in which I know that transitions between certain states are impossible (e.g. I have states A, B, and C, and I know that the probability of the transition A->C is 0). I want to explicitly set the probability of those impossible transitions to zero, freeze those values, and then train the rest of the model as normal. I know that I could manipulate the CRF transitions tensor outside of the training loop, and I know that I could freeze the entire parameter in the usual PyTorch way (`requires_grad = False`), but is there a way to freeze particular cells within the transition matrix?

kmkurn commented 3 months ago

I'm not sure, but I don't think so. This question is perhaps better posted in PyTorch's forum.

Anyway, since the transition tensor is initialised randomly, I'm unsure why freezing these invalid transition cells is desirable. Moreover, if these transitions are invalid then surely they don't occur in the training data. Thus, these cells will never be updated, effectively equivalent to freezing them. Am I missing something?
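
For reference, the per-cell freezing asked about above can be sketched in plain PyTorch with a gradient hook. This is only a minimal sketch under stated assumptions, not a feature of pytorch-crf: `freeze_cells` is a hypothetical helper, the `transitions` parameter here is a stand-in with the same `(num_tags, num_tags)` shape as a CRF transition matrix, and the squared-sum loss is a placeholder for the real negative log-likelihood. Since the matrix holds scores rather than probabilities, a probability of zero is approximated by a large negative score (`-1e4` is an arbitrary choice). Because the CRF likelihood normalises over all tag sequences, an entry may still receive gradient even when its transition never appears in the training data, so masking the gradient is one way to guarantee the value stays fixed.

```python
import torch
from torch import nn


def freeze_cells(param, frozen_mask, value):
    """Hypothetical helper: pin param[frozen_mask] to `value` and block updates there."""
    with torch.no_grad():
        param[frozen_mask] = value
    # The hook replaces the gradient before it is accumulated into param.grad,
    # so the frozen cells always see a gradient of exactly zero.
    return param.register_hook(lambda grad: grad.masked_fill(frozen_mask, 0.0))


torch.manual_seed(0)
num_tags = 3  # states A=0, B=1, C=2

# Stand-in for a CRF transition matrix: entry [i, j] scores the move i -> j.
transitions = nn.Parameter(torch.empty(num_tags, num_tags).uniform_(-0.1, 0.1))

frozen = torch.zeros(num_tags, num_tags, dtype=torch.bool)
frozen[0, 2] = True  # forbid A -> C

handle = freeze_cells(transitions, frozen, -1e4)
before = transitions.detach().clone()

optimizer = torch.optim.SGD([transitions], lr=0.1)

# Placeholder loss that touches every cell; a real model would use the CRF loss.
loss = transitions.pow(2).sum()
loss.backward()
optimizer.step()

print(transitions.detach())
```

The same hook could presumably be registered on the `transitions` parameter of a real CRF module. Note that this only holds the value fixed for optimizers whose update is zero when the gradient is zero; weight decay, for example, would still move the frozen cells, so re-applying the value after each step may be the safer choice. Calling `handle.remove()` unfreezes the cells.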