Human9000 / nd-Mamba2-torch

Only implemented through torch: "bi - mamba2" , "vision- mamba2 -torch". support 1d/2d/3d/nd and support export by jit.script/onnx;
105 stars 2 forks source link

vssd_torch.py #3

Open Muzi010 opened 3 weeks ago

Muzi010 commented 3 weeks ago

y = self.non_casual_linear_attn( rearrange(x, "b l (h p) -> b l h p", p=self.headdim), dt, A, B, C, self.D, H, W ) 运行到这里报错TypeError: 'Tensor' object is not callable

Human9000 commented 3 weeks ago

感谢你的反馈,我这边会去尽快核对,你可以先尝试使用我最新发布的ex_vssd.py文件 该文件已经通过完整的正向传播以及script和onnx导出的测试

Muzi010 commented 3 weeks ago

class tTensor(torch.Tensor): @property def shape(self): return super().shape

shape = super().shape

    # return tuple([int(s) for s in shape])我将这个类进行修改一下就能跑通了,直接返回原来的数据类型,但是我不懂原理,不知道这样对不对。你可以参考一下这里
Human9000 commented 3 weeks ago

经过测试,我并没有成功复现出来你发现的错误,可能是由于我们使用的环境有差异产生的影响,你可以尝试使用ex_vssd.py的版本,在这个版本中,移除了tTensor这个自定义类,因为自定义类的兼容性较差,无法保证稳定运行在任意环境中。 你这种做法是正确的,原因是自定义类里对于shape的返回值数据类型从torch.size修改成了tuple,这和标准的torch.tensor是不一致的,因此无法保证在任意环境下的通用性。

Muzi010 commented 3 weeks ago

好的,谢谢,应该是环境原因。我想问一下,你发布的文件,bi 指的是是双向,ex_vssd.py 中 ex 代表什么?是表示新的修改版么

Human9000 commented 3 weeks ago

首先再次对你的反馈表示感谢,下面是对你提出的反馈的说明: bi是指 Bidirectional,是双向的意思。 ex是指export,是导出的意思。 该项目中的ex版本的代码均支持直接导出成onnx或者jit.script格式,以便于后期部署工作

Muzi010 commented 3 weeks ago

ok,非常感谢。