pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[ONNX] Input node deleted when converting a Conditional random field model #80016

Closed QiusongYang closed 7 months ago

QiusongYang commented 2 years ago

🐛 Describe the bug

CRF model code:

import torch
import torch.nn as nn
from typing import List, Optional

class CRF(nn.Module):
    """Conditional random field.
    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.
    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.
    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.
        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self,
                emissions: torch.Tensor,
                mask: Optional[torch.ByteTensor] = None,
                tags: torch.LongTensor = None,
                reduction: str = 'mean',
                nbest: Optional[int] = None,
                pad_tag: Optional[int] = None) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.
            nbest (`int`): Number of most probable paths for each sequence
            pad_tag (`int`): Tag at padded positions. Often input varies in length and
                the length will be padded to the maximum length in the batch. Tags at
                the padded positions will be assigned with a padding tag, i.e. `pad_tag`
        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
        """
        if tags is not None:
            # training
            if reduction not in ('none', 'sum', 'mean', 'token_mean'):
                raise ValueError(f'invalid reduction: {reduction}')
            if mask is None:
                mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
            if mask.dtype != torch.uint8:
                mask = mask.byte()
            self._validate(emissions, tags=tags, mask=mask)

            if self.batch_first:
                emissions = emissions.transpose(0, 1)
                tags = tags.transpose(0, 1)
                mask = mask.transpose(0, 1)

            # shape: (batch_size,)
            numerator = self._compute_score(emissions, tags, mask)
            # shape: (batch_size,)
            denominator = self._compute_normalizer(emissions, mask)
            # shape: (batch_size,)
            llh = numerator - denominator

            crf_loss = None
            if reduction == 'none':
                crf_loss = llh
            elif reduction == 'sum':
                crf_loss = llh.sum()
            elif reduction == 'mean':
                crf_loss = llh.mean()
            else:
                crf_loss = llh.sum() / mask.float().sum()
            return crf_loss
        else:
            # predict
            predict_paths = self.decode(emissions, mask, nbest=nbest, pad_tag=pad_tag)
            return predict_paths

    def decode(self,
               emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None,
               nbest: Optional[int] = None,
               pad_tag: Optional[int] = None) -> torch.Tensor:
        """Find the most likely tag sequence using Viterbi algorithm.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            nbest (`int`): Number of most probable paths for each sequence
            pad_tag (`int`): Tag at padded positions. Often input varies in length and
                the length will be padded to the maximum length in the batch. Tags at
                the padded positions will be assigned with a padding tag, i.e. `pad_tag`
        Returns:
            A PyTorch tensor of the best tag sequence for each batch of shape
            (nbest, batch_size, seq_length)
        """
        if nbest is None:
            nbest = 1
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.uint8,
                              device=emissions.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, mask=mask)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        predict_paths = None
        if nbest == 1:
            predict_paths = self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
        else:
            predict_paths = self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
        return predict_paths

    def _validate(self, emissions: torch.Tensor,
                  tags: Optional[torch.LongTensor] = None,
                  mask: Optional[torch.ByteTensor] = None) -> None:
        input_shape = emissions.shape
        if len(input_shape) != 3:
            raise ValueError(f'emissions must have dimension of 3, got {len(input_shape)}')
        if input_shape[2] != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {input_shape[2]}')

        if tags is not None:
            tags_shape = tags.shape
            if input_shape[0:2] != tags_shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(input_shape[:2])} and {tuple(tags_shape)}')

        if mask is not None:
            mask_shape = mask.shape
            if input_shape[:2] != mask_shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(input_shape[:2])} and {tuple(mask_shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    # @torch.jit.script
    def _compute_score(self, emissions: torch.Tensor,
                       tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    # @torch.jit.script
    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    # @torch.jit.script
    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor,
                        pad_tag: Optional[int] = None) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # return: (batch_size, seq_length)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
                                  dtype=torch.long, device=device)
        oor_idx = torch.zeros((batch_size, self.num_tags),
                              dtype=torch.long, device=device)
        oor_tag = torch.full((seq_length, batch_size), pad_tag,
                             dtype=torch.long, device=device)

        # - score is a tensor of size (batch_size, num_tags) where for every batch,
        #   value at column j stores the score of the best tag sequence so far that ends
        #   with tag j
        # - history_idx saves where the best tags candidate transitioned from; this is used
        #   when we trace back the best tag sequence
        # - oor_idx saves the best tags candidate transitioned from at the positions
        #   where mask is 0, i.e. out of range (oor)

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tags
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(-1).bool(), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1).bool(), indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score
        # shape: (batch_size, num_tags)
        end_score = score + self.end_transitions
        _, end_tag = end_score.max(dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # insert the best tag at each sequence end (last position with mask == 1)
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
                             end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # The most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size),
                                    dtype=torch.long, device=device)
        best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx], 1, best_tags)
            best_tags_arr[idx] = best_tags.data.view(batch_size)

        return torch.where(mask.bool(), best_tags_arr, oor_tag).transpose(0, 1)

    # @torch.jit.script
    def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,
                              mask: torch.ByteTensor,
                              nbest: int,
                              pad_tag: Optional[int] = None) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # return: (nbest, batch_size, seq_length)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),
                                  dtype=torch.long, device=device)
        oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
                              dtype=torch.long, device=device)
        oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,
                             dtype=torch.long, device=device)

        # - score is a tensor of size (batch_size, num_tags) where for every batch,
        #   value at column j stores the score of the best tag sequence so far that ends
        #   with tag j
        # - history_idx saves where the best tags candidate transitioned from; this is used
        #   when we trace back the best tag sequence
        # - oor_idx saves the best tags candidate transitioned from at the positions
        #   where mask is 0, i.e. out of range (oor)

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            if i == 1:
                broadcast_score = score.unsqueeze(-1)
                broadcast_emission = emissions[i].unsqueeze(1)
                # shape: (batch_size, num_tags, num_tags)
                next_score = broadcast_score + self.transitions + broadcast_emission
            else:
                broadcast_score = score.unsqueeze(-1)
                broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
                # shape: (batch_size, num_tags, nbest, num_tags)
                next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission

            # Find the top `nbest` scores over all possible current tags
            # shape: (batch_size, nbest, num_tags)
            next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)

            if i == 1:
                score = score.unsqueeze(-1).expand(-1, -1, nbest)
                indices = indices * nbest

            # convert to shape: (batch_size, num_tags, nbest)
            next_score = next_score.transpose(2, 1)
            indices = indices.transpose(2, 1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags, nbest)
            score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1).bool(), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1).bool(), indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score shape: (batch_size, num_tags, nbest)
        end_score = score + self.end_transitions.unsqueeze(-1)
        _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # insert the best tag at each sequence end (last position with mask == 1)
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
                             end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # The most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
                                    dtype=torch.long, device=device)
        best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
                         .view(1, -1).expand(batch_size, -1)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags)
            best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest

        return torch.where(mask.unsqueeze(-1).bool(), best_tags_arr, oor_tag).permute(2, 1, 0)
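For reference, `_compute_normalizer` and `_viterbi_decode` above implement the standard CRF forward algorithm and Viterbi recursion, plus masking and padding. The sketch below is a hedged, stripped-down version for checking what they compute: it assumes every sequence is full length (no `mask`, no `pad_tag`) and returns a single best path. The names `crf_log_partition`, `crf_viterbi` and `path_score` are hypothetical and not part of the module; the sketch uses the same `(seq_length, batch_size, num_tags)` layout and is verified against brute-force enumeration on a tiny problem.

```python
import itertools

import torch


def crf_log_partition(emissions, start, trans, end):
    """Forward algorithm: log-sum-exp of the scores of all tag paths.

    emissions: (seq_length, batch_size, num_tags); start/end: (num_tags,);
    trans: (num_tags, num_tags). Returns a (batch_size,) tensor.
    """
    # score[b, j]: log-sum-exp over all prefixes that end in tag j
    score = start + emissions[0]
    for i in range(1, emissions.size(0)):
        # (batch, prev, 1) + (prev, next) + (batch, 1, next) -> reduce over prev
        score = torch.logsumexp(
            score.unsqueeze(2) + trans + emissions[i].unsqueeze(1), dim=1)
    return torch.logsumexp(score + end, dim=1)


def crf_viterbi(emissions, start, trans, end):
    """Viterbi decoding: best tag path per sequence, shape (batch_size, seq_length)."""
    score = start + emissions[0]
    backpointers = []
    for i in range(1, emissions.size(0)):
        # Keep the best previous tag for every (batch, next_tag) pair
        score, best_prev = (
            score.unsqueeze(2) + trans + emissions[i].unsqueeze(1)).max(dim=1)
        backpointers.append(best_prev)
    # Best final tag after adding the end transitions, shape (batch_size,)
    path = [(score + end).argmax(dim=1)]
    for best_prev in reversed(backpointers):
        # Follow the backpointer of the tag chosen at the next timestep
        path.append(best_prev.gather(1, path[-1].unsqueeze(1)).squeeze(1))
    path.reverse()
    return torch.stack(path, dim=1)


def path_score(emissions, start, trans, end, tags, b):
    """Unnormalized score of one explicit tag path for batch element b."""
    s = start[tags[0]] + emissions[0, b, tags[0]]
    for i in range(1, len(tags)):
        s = s + trans[tags[i - 1], tags[i]] + emissions[i, b, tags[i]]
    return s + end[tags[-1]]


# Brute-force check on a tiny problem: enumerate all num_tags ** seq_length paths
torch.manual_seed(0)
T, B, K = 4, 2, 3
emissions = torch.randn(T, B, K)
start, trans, end = torch.randn(K), torch.randn(K, K), torch.randn(K)
params = (emissions, start, trans, end)
log_z = crf_log_partition(*params)
best_paths = crf_viterbi(*params)
for b in range(B):
    all_scores = torch.stack([path_score(*params, tags, b)
                              for tags in itertools.product(range(K), repeat=T)])
    assert torch.allclose(log_z[b], torch.logsumexp(all_scores, dim=0), atol=1e-4)
    best = path_score(*params, best_paths[b].tolist(), b)
    assert torch.allclose(best, all_scores.max(), atol=1e-4)
```

Unlike `_viterbi_decode`, this sketch builds the path with `torch.stack` instead of writing into a preallocated tensor, and it ignores `mask` and `pad_tag`, so padded batches still need the masking logic in the module.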

Convert script:

import unittest
import numpy as np
import torch
import onnxruntime
from src.models.common.crf import CRF

class CrfModelTest(unittest.TestCase):
    def test_crf_to_onnx_model(self):
        crf_model = CRF(num_tags=25, batch_first=True)
        torch_model_path = "/root/NLP_Meta/data/resume_parsing_new/segmentation/20220601_1860/pytorch_crf_model.bin"
        # crf_model.load_state_dict(torch.load(torch_model_path, map_location="cpu"))
        dummy_input = torch.randn([1, 64, 25])
        dummy_mask = torch.ones([1, 64], dtype=torch.uint8)
        input_names = ["emissions", "mask"]
        output_names = ["predict_paths"]
        dynamic_axes = {
            "emissions": {0: "batch", 1: "sentence", 2: "hidden_size"},
            "mask": {0: "batch", 1: "sentence"},
            "predict_paths": {0: "path_num", 1: "batch", 2: "sentence"}
        }
        crf_onnx_model_path = "/root/NLP_Meta/data/resume_parsing_new/segmentation/onnx/20220601_1860_segment_crf.onnx"
        torch.onnx.export(crf_model, ({"emissions": dummy_input, "mask": dummy_mask}, ), crf_onnx_model_path,
                          verbose=True,
                          opset_version=12,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes,
                          )
        print("Convert to onnx model complete.")

    def test_crf_onnx_run(self):
        crf_onnx_model_path = "/root/NLP_Meta/data/resume_parsing_new/segmentation/onnx/20220601_1860_segment_crf.onnx"
        sess = onnxruntime.InferenceSession(crf_onnx_model_path, providers=["CPUExecutionProvider"])
        dummy_input = np.random.randn(1, 64, 25).astype(np.float32)
        # The exported mask is uint8, and the model is meant to take both inputs
        dummy_mask = np.ones((1, 64), dtype=np.uint8)
        outputs = sess.run(output_names=["predict_paths"],
                           input_feed={"emissions": dummy_input, "mask": dummy_mask})
        print(outputs[0], outputs[0].shape)

Part of the exported graph:

graph(%mask : Byte(*, *, strides=[64, 1], requires_grad=0, device=cpu),
      %1240 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1241 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1242 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1243 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1244 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1245 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1246 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1247 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1248 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1249 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1250 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1251 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1252 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1253 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1254 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1255 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1256 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1257 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1258 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1259 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %1260 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %1261 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),

The conversion succeeds, but the exported graph only has a 'mask' input node and no 'emissions' input node. Why?
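One way to narrow this down is to check whether the traced model's output reacts to `emissions` at all, since the TorchScript-based `torch.onnx.export` builds its graph by tracing an `nn.Module`. The sketch below is an empirical check, not a proof: it replaces one input with fresh random values a few times and compares outputs. `depends_on_input`, `uses_x` and `ignores_x` are hypothetical names used only for illustration.

```python
import torch


def depends_on_input(fn, inputs, index, trials=5):
    """Return True if replacing inputs[index] with random values changes fn's output.

    Hypothetical diagnostic helper; inputs[index] must be a floating-point tensor.
    """
    baseline = fn(*inputs)
    for _ in range(trials):
        perturbed = list(inputs)
        perturbed[index] = torch.randn_like(inputs[index])
        if not torch.equal(fn(*perturbed), baseline):
            return True
    # No change observed: the output most likely ignores this input
    return False


# Toy functions standing in for a decoder: one uses x, one ignores it
def uses_x(x, y):
    return x.argmax(dim=-1) * y


def ignores_x(x, y):
    return y * 2


torch.manual_seed(0)
x = torch.randn(2, 5)
y = torch.ones(2, dtype=torch.long)
traced_uses = torch.jit.trace(uses_x, (x, y))
traced_ignores = torch.jit.trace(ignores_x, (x, y))
print(depends_on_input(traced_uses, (x, y), 0))
print(depends_on_input(traced_ignores, (x, y), 0))
```

Applied to the CRF from the issue, the same check would be `traced = torch.jit.trace(crf_model, (dummy_input, dummy_mask))` followed by `depends_on_input(traced, (dummy_input, dummy_mask), 0)` (tracing may emit TracerWarnings for the data-dependent checks in `_validate`). A `False` result would suggest `emissions` is already lost during tracing of the decoding code, before the ONNX conversion step.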

Versions

torch version: 1.10.0 onnx version: 1.11.0

titaiwangms commented 2 years ago

Hi @dagitses,

Is there any way we can access your model to repro?

QiusongYang commented 2 years ago

Using the original randomly initialized parameters has the same effect.

MrRace commented 1 year ago

Can a CRF not be converted to ONNX? Same error here!

Nayahei commented 1 year ago

Same problem! Did you solve it? Please help! Thanks @QiusongYang

QiusongYang commented 1 year ago

There has been no solution yet. For now, the CRF model runs on CPU.

Nayahei commented 1 year ago

thanks

thiagocrepaldi commented 7 months ago

Closing as there is no repro

thiagocrepaldi commented 7 months ago

torch.onnx.export is in maintenance mode and we don't plan to add new operators/features or fix complex issues.

Please try the new ONNX exporter and reopen this issue with a full repro if it also doesn't work for you: quick torch.onnx.dynamo_export API tutorial