Open JindongJiang opened 2 days ago
Hi, thank you for the great work! I have a question about a potential bug. In `mamba_ssm/ops/selective_scan_interface.py` you have the following. Here, `x_dbl` is of shape `(b l, d)`, but when you perform the decimation you index the first dimension, `x_dbl = x_dbl[not_decimated, :]`, whose length is `b l`, not `l`. Is this expected behavior or a mistake? Thanks!

Actually, since `not_decimated` holds batched indices, should we use `gather` instead? For example:

```python
x_dbl = rearrange(x_dbl, "(b l) d -> b d l", b=xz.shape[0])
x_dbl = torch.gather(x_dbl, 2, non_decimated_indices.unsqueeze(1).expand(-1, x_dbl.shape[1], -1))
x_dbl = rearrange(x_dbl, "b d l -> (b l) d")
delta = torch.gather(delta, 2, non_decimated_indices.unsqueeze(1).expand(-1, delta.shape[1], -1))
conv1d_out = torch.gather(conv1d_out, 2, non_decimated_indices.unsqueeze(1).expand(-1, conv1d_out.shape[1], -1))
z = torch.gather(z, 2, non_decimated_indices.unsqueeze(1).expand(-1, z.shape[1], -1))
```
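For anyone following along, here is a minimal, self-contained sketch of the batched-`gather` pattern suggested in this thread. It is not the repo's code: the shapes (`b`, `l`, `d`, `k`) and the index tensor are made up for illustration, and plain `reshape`/`permute` stand in for `einops.rearrange`. It shows why `gather` along the length axis (dim 2) selects a possibly different set of positions per batch element, which flat indexing on the `(b l, d)` layout cannot do.

```python
# Hypothetical sketch (not the repo's code): keep a per-batch subset of
# sequence positions with torch.gather, as suggested in the thread above.
import torch

b, l, d, k = 2, 6, 4, 3  # batch, length, feature dim, positions kept per batch

# Flattened activations of shape (b*l, d), the layout x_dbl has in the issue.
x = torch.arange(b * l * d, dtype=torch.float32).reshape(b * l, d)

# Per-batch indices of the non-decimated positions, shape (b, k).
# Note each batch element keeps a *different* set of positions.
non_decimated_indices = torch.tensor([[0, 2, 5], [1, 3, 4]])

# Reshape to (b, d, l) so the gather runs over the length axis (dim=2),
# mirroring the rearrange("(b l) d -> b d l") step.
x_bdl = x.reshape(b, l, d).permute(0, 2, 1)

# Expand indices to (b, d, k) so every feature channel uses the same positions.
idx = non_decimated_indices.unsqueeze(1).expand(-1, d, -1)
kept = torch.gather(x_bdl, 2, idx)  # (b, d, k)

# Back to the flattened (b*k, d) layout, mirroring rearrange("b d l -> (b l) d").
kept_flat = kept.permute(0, 2, 1).reshape(b * k, d)

# Reference result via plain per-batch indexing; should match exactly.
ref = torch.cat([x.reshape(b, l, d)[i, non_decimated_indices[i]] for i in range(b)])
assert torch.equal(kept_flat, ref)
```

The `unsqueeze(1).expand(-1, d, -1)` step is needed because `torch.gather` requires the index tensor to have the same number of dimensions as the input; expanding broadcasts the position indices across all `d` feature channels without copying memory.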