zhouhaoyi / Informer2020

The GitHub repository for the paper "Informer" accepted by AAAI 2021.
Apache License 2.0
5.38k stars 1.12k forks source link

希望支持将模型转换为torchscript格式 #60

Closed MaleicAcid closed 3 years ago

MaleicAcid commented 3 years ago

希望能够支持将模型转换为torchscript格式,便于得到更广泛的应用。 我希望能够在Java App中调用训练好的Informer模型, 在了解了JDL库后得知需要先将pth模型转换为torchscript格式 我尝试使用如下代码进行转换

Exp = Exp_Informer
exp = Exp(args) # 使用训练时相同的参数初始化模型
pthfile = './checkpoints/test/checkpoint.pth'

examples = exp.trace() # 为Informer类新增一个方法以便获取forward()所需的参数, 在此例中返回值是一个tuple()
model = exp.model # 获取模型并加载
model.load_state_dict(torch.load(pthfile))

# 尝试推理并转换
traced_script_module = torch.jit.trace(model, examples)
traced_script_module.save("./traced_model.pt")

我得到了如下错误, File "E:\pythonspace\deep_learning\Informer2020\models\attn.py", line 110, in forward U_part = self.factor np.ceil(np.log(L_K)).astype('int').item() # cln(L_k) AttributeError: 'Tensor' object has no attribute 'astype'

https://zhuanlan.zhihu.com/p/146453159 我猜测这是因为使用了numpy 中的np.ceil(), np.log()函数导致的, 我尝试将其替换为torch对应的函数但仍不奏效

# U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # 转换中应尽量避免使用np [https://zhuanlan.zhihu.com/p/146453159] 
U_part = self.factor * torch.ceil(torch.log(L_K)).int() # 尝试替换为torch对应的写法

这样改之后错误变成了这样: torch.jit._trace.TracingCheckError: Tracing failed sanity checks! encountered an exception while running the Python function with test inputs. Exception: log(): argument 'input' (position 1) must be Tensor, not int

希望大佬能指点一下,这里如果想不用np的话应该怎么修改,万分感谢❀❀❀

zhouhaoyi commented 3 years ago

目前我们正在开发往其他平台迁移的工程,可能暂时没有时间来处理这个issue。欢迎你继续探索转换为torchscript的过程。 另外第一个AttributeError可能是自动转换的时候已经进行了np到torch的转换,而torch里面是没有astype方法的,换一个抽象实现吧。 单纯从第二个报错来看你只需要转换log的input为tensor即可。因为目前你使用的torch函数是针对tensor对象使用的。

MaleicAcid commented 3 years ago
file: models/attn.py
class ProbAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask):
        ......
        # 这两行改成这样
        U_part = self.factor * torch.ceil(torch.log(torch.tensor(L_K).float())).int() # c*ln(L_k)
        u = self.factor * (torch.ceil(torch.log(torch.tensor(L_Q).float())).int()) # c*ln(L_q)

这样能成功导出pt文件,实测能用DJL进行调用。 如果不加那个 .float() , DJL会报一个“log() 不支持long类型”的错

[main] WARN ai.djl.pytorch.jni.LibUtils - No matching cuda flavor for win found: cu111.
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 1
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 2
predict output: 
NDList size: 1
0 : (1, 24, 6) float64

end informer predict use 354 ms
cookieminions commented 3 years ago
file: models/attn.py
class ProbAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask):
        ......
        # 这两行改成这样
        U_part = self.factor * torch.ceil(torch.log(torch.tensor(L_K).float())).int() # c*ln(L_k)
        u = self.factor * (torch.ceil(torch.log(torch.tensor(L_Q).float())).int()) # c*ln(L_q)

这样能成功导出pt文件,实测能用DJL进行调用。 如果不加那个 .float() , DJL会报一个“log() 不支持long类型”的错

[main] WARN ai.djl.pytorch.jni.LibUtils - No matching cuda flavor for win found: cu111.
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 1
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 2
predict output: 
NDList size: 1
0 : (1, 24, 6) float64

end informer predict use 354 ms

Great Job!

我们将会在之后的更新中加入这部分代码,非常感谢!

leepengcheng commented 3 years ago
file: models/attn.py
class ProbAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask):
        ......
        # 这两行改成这样
        U_part = self.factor * torch.ceil(torch.log(torch.tensor(L_K).float())).int() # c*ln(L_k)
        u = self.factor * (torch.ceil(torch.log(torch.tensor(L_Q).float())).int()) # c*ln(L_q)

这样能成功导出pt文件,实测能用DJL进行调用。 如果不加那个 .float() , DJL会报一个“log() 不支持long类型”的错

[main] WARN ai.djl.pytorch.jni.LibUtils - No matching cuda flavor for win found: cu111.
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 1
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 2
predict output: 
NDList size: 1
0 : (1, 24, 6) float64

end informer predict use 354 ms

貌似只该这2行不行,请教下还有别的要注意的吗