assafbk / DeciMamba

Official implementation of the paper "DeciMamba: Exploring the Length Extrapolation Potential of Mamba"
MIT License

Potential bug or expected behavior in selective_scan_interface #2

Open JindongJiang opened 2 days ago

JindongJiang commented 2 days ago

Hi, thank you for the great work! I have a question about a potential bug. In `mamba_ssm/ops/selective_scan_interface.py` you have the following:

```python
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)

# 'peak inside' the SSM for delta_t, and then pool w.r.t it
not_decimated = get_non_decimated_indices(delta.detach().clone(), delta_bias.detach().clone(), resp_len, layer_num=layer_num, decimation_config=decimation_config)
x_dbl = x_dbl[not_decimated,:]
```

Here, `x_dbl` has shape `(b*l, d)`, but the decimation indexes its first dimension, `x_dbl = x_dbl[not_decimated,:]`, and that dimension has length `b*l`, not `l`. Is this an expected behavior or a mistake?
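To make the mismatch concrete, here is a minimal sketch with toy shapes and made-up index values (not the repo's actual `get_non_decimated_indices` output), showing how indexing the flattened `(b*l, d)` tensor with per-sequence token positions diverges from a selection that accounts for the batch offset:

```python
import torch

b, l, d = 2, 4, 3
# flattened activations: row i*l + j holds (batch i, token j)
x_dbl = torch.arange(b * l * d, dtype=torch.float32).reshape(b * l, d)

# hypothetical per-batch token indices, shape (b, k): keep tokens 0 and 2
not_decimated = torch.tensor([[0, 2], [0, 2]])

# naive flat indexing reuses the same row numbers for every batch,
# so batch 1's selection actually pulls rows belonging to batch 0
wrong = x_dbl[not_decimated.flatten(), :]   # rows 0, 2, 0, 2

# offsetting each batch's token indices by its start row i*l selects
# the intended rows from both sequences
offsets = torch.arange(b).unsqueeze(1) * l  # (b, 1)
flat_idx = (not_decimated + offsets).flatten()  # rows 0, 2, 4, 6
right = x_dbl[flat_idx, :]
```

The two results agree on batch 0 but differ on batch 1, which is exactly the concern above.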

Thanks!

JindongJiang commented 2 days ago

Actually, since `not_decimated` holds batched indices, should we use `torch.gather` instead? For example:

```python
x_dbl = rearrange(x_dbl, "(b l) d -> b d l", b=xz.shape[0])
x_dbl = torch.gather(x_dbl, 2, non_decimated_indices.unsqueeze(1).expand(-1, x_dbl.shape[1], -1))
x_dbl = rearrange(x_dbl, "b d l -> (b l) d")
delta = torch.gather(delta, 2, non_decimated_indices.unsqueeze(1).expand(-1, delta.shape[1], -1))
conv1d_out = torch.gather(conv1d_out, 2, non_decimated_indices.unsqueeze(1).expand(-1, conv1d_out.shape[1], -1))
z = torch.gather(z, 2, non_decimated_indices.unsqueeze(1).expand(-1, z.shape[1], -1))
```
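For what it's worth, a self-contained toy example of the `unsqueeze(1).expand(...)` + `gather` pattern above (shapes and index values are made up, not from the repo) confirms that each batch element selects its own token positions along the length dimension:

```python
import torch

b, d, l = 2, 3, 5
x = torch.arange(b * d * l, dtype=torch.float32).reshape(b, d, l)

# hypothetical per-batch indices of the tokens to keep, shape (b, k)
idx = torch.tensor([[0, 3], [1, 4]])

# expand to (b, d, k) so every channel gathers the same token positions
kept = torch.gather(x, 2, idx.unsqueeze(1).expand(-1, d, -1))
# kept[i, :, j] == x[i, :, idx[i, j]] -- positions differ per batch
```

Unlike flat fancy indexing, `gather` keeps the batch dimension intact, so batch 0 keeps tokens 0 and 3 while batch 1 keeps tokens 1 and 4.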