AmeenAli / HiddenMambaAttn

Official PyTorch Implementation of "The Hidden Attention of Mamba Models"

Type Error #2

sivaji123256 closed this issue 9 months ago

sivaji123256 commented 9 months ago

Hi @AmeenAli, thanks for the great work. I was trying to replicate the Jupyter notebook and followed the same installation procedure you mentioned, but I am facing the following issue. Can you suggest how I can fix it? Thanks in advance.

TypeError                                 Traceback (most recent call last)
Cell In[5], line 3
      1 image = transform_for_eval('./images/1.jpg').unsqueeze(0).cuda()
      2 raw_image = Image.open('./images/1.jpg')
----> 3 map_raw_atten, logits = generate_raw_attn(model, image)
      4 map_mamba_attr, _ = generate_mamba_attr(model, image)
      5 map_rollout, _ = generate_rollout(model, image)

File ~/vision_mamba/HiddenMambaAttn/vim/xai_utils.py:21, in generate_raw_attn(model, image, start_layer)
     19 def generate_raw_attn(model, image, start_layer=15):
     20     image.requires_grad_()
---> 21     logits = model(image)
     22     all_layer_attentions = []
     23     cls_pos = 98

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/vision_mamba/HiddenMambaAttn/vim/models_mamba.py:549, in VisionMamba.forward(self, x, return_features, inference_params, if_random_cls_token_position, if_random_token_rank)
    548 def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
--> 549     x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
    550     if return_features:
    551         return x

File ~/vision_mamba/HiddenMambaAttn/vim/models_mamba.py:486, in VisionMamba.forward_features(self, x, inference_params, if_random_cls_token_position, if_random_token_rank)
    483 if residual is not None:
    484     residual = residual.flip([1])
--> 486 hidden_states, residual = layer(
    487     hidden_states, residual, inference_params=inference_params
    488 )
    489 else:
    490     # get two layers in a single for-loop
    491     for i in range(len(self.layers) // 2):

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/vision_mamba/HiddenMambaAttn/vim/models_mamba.py:141, in Block.forward(self, hidden_states, residual, inference_params)
    131 else:
    132     hidden_states, residual = fused_add_norm_fn(
    133         self.drop_path(hidden_states),
    134         self.norm.weight,
   (...)
    139         eps=self.norm.eps,
    140     )
--> 141 hidden_states = self.mixer(hidden_states, inference_params=inference_params)
    142 residual.register_hook(self.save_gradients)
    143 return hidden_states, residual

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/vision_mamba/HiddenMambaAttn/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py:216, in Mamba.forward(self, hidden_states, inference_params)
    214 elif self.bimamba_type == "v2":
    215     A_b = -torch.exp(self.A_b_log.float())
--> 216     out, xai_a = mamba_inner_fn_no_out_proj(
    217         xz,
    218         self.conv1d.weight,
    219         self.conv1d.bias,
    220         self.x_proj.weight,
    221         self.dt_proj.weight,
    222         A,
    223         None,  # input-dependent B
    224         None,  # input-dependent C
    225         self.D.float(),
    226         delta_bias=self.dt_proj.bias.float(),
    227         delta_softplus=True,
    228     )
    229     out_b, xai_b = mamba_inner_fn_no_out_proj(
    230         xz.flip([-1]),
    231         self.conv1d_b.weight,
   (...)
    240         delta_softplus=True,
    241     )
    243     xai_vector = xai_a["xai_vector"]

File ~/vision_mamba/HiddenMambaAttn/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py:640, in mamba_inner_fn_no_out_proj(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    635 def mamba_inner_fn_no_out_proj(
    636     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    637     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    638     C_proj_bias=None, delta_softplus=True
    639 ):
--> 640     return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    641                                        A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py:113, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
    111 if cast_inputs is None:
    112     args[0]._fwd_used_autocast = torch.is_autocast_enabled()
--> 113     return fwd(*args, **kwargs)
    114 else:
    115     autocast_context = torch.is_autocast_enabled()

File ~/vision_mamba/HiddenMambaAttn/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py:177, in MambaInnerFnNoOutProj.forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl)
    175 x, z = xz.chunk(2, dim=1)
    176 conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
--> 177 conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
    178 # We're being very careful here about the layout, to avoid extra transposes.
    179 # We want delta to have d as the slowest moving dimension
    180 # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    181 x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)

TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:

    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: Optional[torch.Tensor], arg6: bool) -> torch.Tensor

Invoked with: tensor([[[-0.7431, -1.0921, -1.0823,  ..., -6.5269, -5.7147, -6.1725], ...]], device='cuda:0', requires_grad=True),
              tensor([[-1.5810e-02, -1.3943e-01, -2.3509e-01, -1.4676e-01], ...], device='cuda:0', requires_grad=True),
              Parameter containing: tensor([-1.0635e+00,  2.4638e-01, -7.9336e-02,  ...], device='cuda:0', requires_grad=True),
              True

image = transform_for_eval('./images/2.jpg').unsqueeze(0).cuda()

Itamarzimm commented 9 months ago

Thank you! First, could you please confirm that you installed causal-conv1d from our source using the -e (editable) option, as specified in the instructions? This should solve the problem. Additionally, please note that this issue also exists in the original Vision Mamba repository (for more details, see https://github.com/hustvl/Vim/issues/34).
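
For reference, a small check like the following (an illustrative sketch, assuming it runs in the same environment as the notebook) shows which causal-conv1d build is active. An editable install from the repository's bundled sources should report a binding whose signature matches the four-argument call in mamba_ssm/ops/selective_scan_interface.py, whereas the seven-argument signature in the TypeError above is what a newer PyPI build reports.

```python
# Diagnostic sketch: inspect the installed causal-conv1d package and the
# compiled extension that the traceback actually calls.
import causal_conv1d        # Python package; exposes __version__
import causal_conv1d_cuda   # compiled extension used in selective_scan_interface.py

print(causal_conv1d.__version__)
# The pybind11 docstring lists the argument types the binding accepts.
print(causal_conv1d_cuda.causal_conv1d_fwd.__doc__)
```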

sivaji123256 commented 9 months ago

@Itamarzimm, thanks for your immediate response. The requirement is causal-conv1d==1.1.0; please highlight that in your README. I had installed directly from the instructions without checking the version. I will try reinstalling the required version and update here.
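
Pinning that version from PyPI would look roughly like this (a sketch, not taken from the repository's README; restart the kernel afterwards so the new build is actually loaded):

```python
# Sketch: reinstall the exact causal-conv1d version named above. Newer releases
# change the causal_conv1d_fwd() signature, which is what triggers the TypeError.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "causal-conv1d"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "causal-conv1d==1.1.0"])
```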

Itamarzimm commented 9 months ago

Sure. Additionally, this issue is likely related to the fix documented here: https://github.com/AmeenAli/HiddenMambaAttn/pull/1#issue-2170023219. It should be resolved now. Thanks, @bhoov!
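
For anyone hitting this later, a rough update-and-reinstall sequence, assuming the repository checkout is the current directory and the bundled causal-conv1d sources sit in a causal-conv1d/ subdirectory (an assumed path, not taken from the README):

```python
# Hedged sketch: pull the fix, then reinstall the bundled causal-conv1d in
# editable (-e) mode so the compiled extension matches the Python wrapper.
# "./causal-conv1d" is an assumed location of the bundled sources.
import subprocess
import sys

subprocess.check_call(["git", "pull"])
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "causal-conv1d"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./causal-conv1d"])
```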

sivaji123256 commented 9 months ago

Thanks @Itamarzimm, it was resolved.

sivaji123256 commented 9 months ago

@Itamarzimm, I was also trying to reproduce the same results for the ViT-Small model you mentioned in the PDF attached to the README. Could you guide me on how to replicate this for ViT-Small? Thanks in advance.

AmeenAli commented 9 months ago

@sivaji123256 For the ViT models, you can use this repo: Transformer-Explainability.