rikeda71 / TorchCRF

An Implementation of CRF (Conditional Random Fields) in PyTorch 1.0
MIT License

There is an error in the annotation #7

Closed: YangHaha11514 closed this issue 4 years ago

YangHaha11514 commented 4 years ago
    def forward(
        self, h: FloatTensor, labels: LongTensor, mask: BoolTensor
    ) -> FloatTensor:
        """
        :param h: hidden matrix (seq_len, batch_size, num_labels)
        :param labels: answer labels of each sequence
                       in mini batch (seq_len, batch_size)
        :param mask: mask tensor of each sequence
                     in mini batch (seq_len, batch_size)
        :return: The log-likelihood (batch_size)
        """

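The docstring of this function gives the wrong shapes for the parameters h, labels, and mask. They should be batch-first: (batch_size, seq_len, num_labels) for h, and (batch_size, seq_len) for labels and mask.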
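To make the batch-first convention concrete, here is a minimal sketch that checks input shapes against it. It assumes only what the comment above states: h is (batch_size, seq_len, num_labels), and labels and mask are (batch_size, seq_len). The name `check_crf_input_shapes` is hypothetical and is not part of TorchCRF. The helper works on plain shape tuples, so it runs without PyTorch.

```python
from typing import Tuple


def check_crf_input_shapes(
    h_shape: Tuple[int, ...],
    labels_shape: Tuple[int, ...],
    mask_shape: Tuple[int, ...],
) -> Tuple[int, int, int]:
    """Validate batch-first CRF input shapes (hypothetical helper).

    :param h_shape: shape of the hidden matrix (batch_size, seq_len, num_labels)
    :param labels_shape: shape of the answer labels (batch_size, seq_len)
    :param mask_shape: shape of the mask tensor (batch_size, seq_len)
    :return: (batch_size, seq_len, num_labels)
    """
    if len(h_shape) != 3:
        raise ValueError(
            f"h must be 3-D (batch_size, seq_len, num_labels), got {tuple(h_shape)}"
        )
    batch_size, seq_len, num_labels = h_shape
    # labels and mask must share the leading (batch_size, seq_len) dims of h
    for name, shape in (("labels", labels_shape), ("mask", mask_shape)):
        if tuple(shape) != (batch_size, seq_len):
            raise ValueError(
                f"{name} must be (batch_size, seq_len) = {(batch_size, seq_len)}, "
                f"got {tuple(shape)}"
            )
    return batch_size, seq_len, num_labels


# A batch of 2 sequences, each 3 steps long, scored over 5 labels.
print(check_crf_input_shapes((2, 3, 5), (2, 3), (2, 3)))  # → (2, 3, 5)
```

With PyTorch tensors you could pass `h.shape`, `labels.shape`, and `mask.shape` directly, since `torch.Size` is a tuple subclass. Passing labels laid out seq-first, e.g. `(3, 2)` alongside an `h` of shape `(2, 3, 5)`, raises a `ValueError`. Note that the check only verifies that the three shapes agree with each other; it cannot tell whether a consistently transposed set of inputs was meant to be seq-first.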

rikeda71 commented 4 years ago

@YangHaha11514 Thank you for finding these docstring mistakes. If you don't mind, could you fix them and open a pull request?