Open chen-c-alt opened 2 months ago
修改: 1.在源文件261行前加入: while pad_size > in_f : x = torch.cat([x, x[..., :]], dim=-1) pad_size-=in_f 这是由于没考虑pad_size > in_f得情况导致。
2.在源文件293行代码替换为: if out_x.numel() == 0: out_x = out_x.view(x.shape[:-1], out_f) else : out_x = out_x.view(x.shape[:-1], -1)[..., :out_f] 这是由于未考虑 out_x 中element=0得情况。
感谢您的反馈,不过您提到的情况都比较特殊。
pad_size>in_f
r
in_f
修改: 1.在源文件261行前加入: while pad_size > in_f : x = torch.cat([x, x[..., :]], dim=-1) pad_size-=in_f 这是由于没考虑pad_size > in_f得情况导致。
2.在源文件293行代码替换为: if out_x.numel() == 0: out_x = out_x.view(x.shape[:-1], out_f) else : out_x = out_x.view(x.shape[:-1], -1)[..., :out_f] 这是由于未考虑 out_x 中element=0得情况。