alibaba / TinyNeuralNetwork

TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework.
MIT License
760 stars 116 forks source link

关于模型中使用了nn.Parameter()会导致转TFlite会失败的问题 #330

Closed Bun-TianYi closed 5 months ago

Bun-TianYi commented 5 months ago

当我在尝试将我的模型转换为TFlite模型时出现了如下报错:

export.py 95 <module>

export.py 33 ckpt2tflite
converter.convert()

base.py 480 convert
self.init_operations()

base.py 443 init_operations
converter.parse(node, attrs, args, self.common_graph)

aten.py 1187 parse
assert False, "other should have type int, float, tensor in aten::mul(input, other)"

AssertionError:
other should have type int, float, tensor in aten::mul(input, other)

调试追踪后发现在一些层中,我使用nn.Parameter()方法将一些自定义的需要被训练的参数进行了注册,此时这些参数虽然还是张量,但被包裹在了Parameter containing中,这导致项目源码中进行类型检查的函数无法正确识别此类参数,以下时项目中转换失败的关键代码:


class ATenAddOperator(ATenAddSchema):
    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)

        self.run(node)

        other = self.input_tensors[1]
        alpha = self.input_tensors[-1]
        assert alpha == 1, "Only alpha == 1 is supported"

        if type(other) in (int, float, bool):
            self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
        elif type(other) != torch.Tensor:
            assert False, "other should have type int, float, tensor in aten::add(input, other)"

        self.elementwise_binary(tfl.AddOperator, graph_converter, True)

该函数检查了变量是否为int,float,bool以及torch.tensor,不在上述数据类型的变量都会引发一个assert error,但项目管理员似乎没有考虑到Parameter containing类型的变量是torch.tensor的子类型,能够进行正常的张量运算,或者调用.data方法也能够将Parameter containing类型的变量变成torch.tensor的变量(当然我不确定这样是否会引发其他未知错误导致转换的精度丢失),我姑且在源码中对两个算子进行了一些改动,成功让模型转换成功:

class ATenAddOperator(ATenAddSchema):
    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)

        self.run(node)

        other = self.input_tensors[1]
        alpha = self.input_tensors[-1]
        assert alpha == 1, "Only alpha == 1 is supported"

        if isinstance(other, torch.nn.Parameter):
            self.input_tensors[1] = self.input_tensors[1].data
            other = self.input_tensors[1]

        if type(other) in (int, float, bool):
            self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
        elif type(other) != torch.Tensor:
            assert False, "other should have type int, float, tensor in aten::add(input, other)"

        self.elementwise_binary(tfl.AddOperator, graph_converter, True)

class ATenMulOperator(ATenMulSchema):
    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)

        self.run(node)

        other = self.input_tensors[1]

        if isinstance(other, torch.nn.Parameter):
            self.input_tensors[1] = self.input_tensors[1].data
            other = self.input_tensors[1]

        if type(other) in (int, float):
            self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)

        elif type(other) != torch.Tensor:
            assert False, "other should have type int, float, tensor in aten::mul(input, other)"

        self.elementwise_binary(tfl.MulOperator, graph_converter, True)

项目人员可以看看这样的改动是否会引起精度丢失的问题,虽然模型转换成功,也能够正常推理使用,但我对这样简单的改动抱有一定疑虑

peterjc123 commented 5 months ago

@Bun-TianYi 直接改成if not isinstance(other, torch.Tensor)就好了吧,这样也不破坏变量的传播关系

Bun-TianYi commented 5 months ago

@Bun-TianYi 直接改成if not isinstance(other, torch.Tensor)就好了吧,这样也不破坏变量的传播关系

我认为这样会引发一些恶性bug,因为在一些特殊需求的模型中是有人将常量也作为模型的一层输出使用,所以源码中的 if type(other) in (int, float): self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype) 这一句是有必要的

peterjc123 commented 5 months ago

@Bun-TianYi 直接改成if not isinstance(other, torch.Tensor)就好了吧,这样也不破坏变量的传播关系

我认为这样会引发一些恶性bug,因为在一些特殊需求的模型中是有人将常量也作为模型的一层输出使用,所以源码中的 if type(other) in (int, float): self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype) 这一句是有必要的

我的意思是把

        elif type(other) != torch.Tensor:
            assert False, "other should have type int, float, tensor in aten::mul(input, other)"

改成

        elif not isinstance(other, torch.Tensor):
            assert False, "other should have type int, float, tensor in aten::mul(input, other)"
Bun-TianYi commented 5 months ago

@Bun-TianYi 直接改成if not isinstance(other, torch.Tensor)就好了吧,这样也不破坏变量的传播关系

我认为这样会引发一些恶性bug,因为在一些特殊需求的模型中是有人将常量也作为模型的一层输出使用,所以源码中的 if type(other) in (int, float): self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype) 这一句是有必要的

我的意思是把

        elif type(other) != torch.Tensor:
            assert False, "other should have type int, float, tensor in aten::mul(input, other)"

改成

        elif not isinstance(other, torch.Tensor):
            assert False, "other should have type int, float, tensor in aten::mul(input, other)"

哦哦!我理解错了,确实,这样修改更优雅了!

peterjc123 commented 5 months ago

@Bun-TianYi #331 应该能修复这个问题