fastnlp / fastNLP

fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.
https://gitee.com/fastnlp/fastNLP
Apache License 2.0

CRF module error when converting a .pth model to a .pt (TorchScript) model #371

Open wanilyer opened 2 years ago

wanilyer commented 2 years ago

This looks like a bug triggered by the CRF module. Could someone please take a look and advise how to fix it? Thanks!

Describe the bug

Description

torch.save(model, model_path)

Error message:

Traceback (most recent call last):
  File "D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py", line 154, in <module>
    export_torchscript(model_path=model_path, model_file=model_file)
  File "D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py", line 125, in export_torchscript
    traced_script_module = torch.jit.trace(model, args)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 1037, in trace_module
    check_tolerance, _force_outplace, True, _module_class)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\autograd\grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 675, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
    Node:
        %393 : Tensor = prim::Constant[value=<Tensor>]() # C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\fastNLP\modules\decoder\crf.py:295:0
    Source Location:
        C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\fastNLP\modules\decoder\crf.py(295): viterbi_decode
        D:\deeplearning-NLP\flat_lattice_transformer\V1\models.py(524): forward
        C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\nn\modules\module.py(534): _slow_forward
        C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\nn\modules\module.py(548): __call__
        C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py(1027): trace_module
        C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py(875): trace
        D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py(125): export_torchscript
        D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py(154): <module>
    Comparison exception:   Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
yhcc commented 2 years ago

Which version of fastNLP are you using? Is it from GitHub, or installed directly via pip install fastNLP?

yhcc commented 2 years ago

Judging from the error message, I suspect the problem is this line: https://github.com/fastnlp/fastNLP/blob/b127963f213226dc796720193965d86dface07d5/fastNLP/modules/decoder/crf.py#L307. Changing it to flip_mask = torch.logical_not(mask) should fix it. The root cause looks similar to https://github.com/pytorch/pytorch/issues/33692: bool tensors do not support some operations, so when TorchScript checks the tensors before and after conversion, the results fail to match.
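The failing operation can be reproduced in isolation. A minimal sketch, independent of fastNLP, showing both the rejected bool subtraction and the suggested fix:

```python
import torch

mask = torch.tensor([True, False, True])

# Arithmetic on bool tensors is rejected by PyTorch:
try:
    flip_mask = 1 - mask
except RuntimeError as e:
    # "Subtraction, the `-` operator, with a bool tensor is not supported. ..."
    print(e)

# Inverting the mask explicitly avoids the error:
flip_mask = torch.logical_not(mask)  # equivalently: ~mask
print(flip_mask)  # tensor([False,  True, False])
```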

wanilyer commented 2 years ago

Thanks for your reply.

I'm using version 0.5.0, installed directly with pip install fastNLP.

The line that errors should be the one below: https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L295

yhcc commented 2 years ago

I see. In that case, changing that line to mask = mask.transpose(0, 1).data.to(torch.bool) should do it. The point is to keep bool-typed data out of any comparison or arithmetic operations.

wanilyer commented 2 years ago

That doesn't seem to work; it still fails with the same error.

In addition, the following line also raises an error:

https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/embeddings/embedding.py#L208

TypeError: tuple expected at most 1 arguments, got 2

I changed it to return torch.Size((self.num_embedding, self._embed_size)) and the problem went away.
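That fix matches how torch.Size behaves: like tuple, its constructor takes a single iterable, not multiple positional arguments. A small sketch, unrelated to fastNLP internals:

```python
import torch

# torch.Size is a tuple subclass, so it accepts one iterable argument:
s = torch.Size((100, 256))
print(s)  # torch.Size([100, 256])

# Passing two separate integers fails the same way tuple(1, 2) would:
try:
    torch.Size(100, 256)
except TypeError as e:
    print(e)  # tuple expected at most 1 argument, got 2
```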

wanilyer commented 2 years ago

After commenting out that line, I found two more places further down with the same problem; I stopped commenting things out after that:

https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L314-L315

https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L323

yhcc commented 2 years ago

Indeed. This part of the fastNLP code probably isn't well suited for JIT conversion: there is too much Python-side branching and too many CPU operations. You will likely need to adapt the code yourself, working from the original. With JIT, any tensor-valued constant seems to cause trouble, and fastNLP makes heavy use of constants to represent certain numbers.
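A minimal illustration of the underlying limitation (the module here is illustrative, not fastNLP code): torch.jit.trace records one concrete execution, so data-dependent branches are baked in as constants, whereas torch.jit.script compiles the control flow itself:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # data-dependent branch: the tracer only sees the path that
        # the example input happens to take
        if x.sum() > 0:
            return x + 1
        return x - 1

m = Toy()
traced = torch.jit.trace(m, torch.ones(3))  # warns, then freezes the "+1" path
scripted = torch.jit.script(m)              # compiles both branches

x = -torch.ones(3)
print(traced(x))    # tensor([0., 0., 0.])  -- wrong branch, fixed at trace time
print(scripted(x))  # tensor([-2., -2., -2.])
```

This is why code with heavy branching on tensor values, like the Viterbi decode in the CRF module, tends to fail trace-time sanity checks.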