zmoka-zht / CDMamba


Training error #9

Closed YonghuiTAN22 closed 1 week ago

YonghuiTAN22 commented 1 month ago

Traceback (most recent call last):
  File "/home/yonghui/project/CDMamba/train_cd.py", line 121, in <module>
    pred_img = cd_model(train_im1, train_im2)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yonghui/project/CDMamba/models/CDMamba.py", line 359, in forward
    x1, down_x1 = self.encode(x1)
  File "/home/yonghui/project/CDMamba/models/CDMamba.py", line 343, in encode
    x = down(x)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yonghui/project/CDMamba/models/CDMamba.py", line 112, in forward
    x = self.conv1(x)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yonghui/project/CDMamba/models/CDMamba.py", line 55, in forward
    x_mamba = self.convmamba(x_norm) + self.skip_scale * x_flat
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yonghui/project/CDMamba/models/mamba_customer.py", line 231, in forward
    out = mamba_inner_fn_no_out_proj(
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 632, in mamba_inner_fn_no_out_proj
    return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/home/yonghui/miniconda3/envs/cdmamba/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 177, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias,None, True)
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:

  1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: Optional[torch.Tensor], arg6: bool) -> torch.Tensor

Invoked with: tensor([[[-0.3382,  1.2468,  1.2655,  ...,  0.4876,  0.3880, -0.2435],
     [ 0.3241, -0.0224, -0.0093,  ...,  0.1298,  0.2758,  0.1078],
     [-0.2305,  0.4938,  0.5060,  ..., -0.8157, -0.7583, -0.7534],
     ...,
     [ 0.0138, -0.4216, -0.4223,  ...,  0.4248,  0.3596,  0.8166],
     [-0.3435, -0.3589, -0.3545,  ..., -0.3836, -0.2709, -0.4491],
     [ 0.3027, -0.2779, -0.3023,  ..., -1.4457, -1.5415, -0.8076]],

    [[ 0.3048,  0.1459, -0.1444,  ...,  0.4878,  0.1939,  0.1200],
     [ 0.3129,  0.0807, -0.1589,  ...,  0.5834,  0.4714,  0.6267],
     [-1.6001, -1.5727, -1.1166,  ..., -1.5087, -1.4359, -1.2867],
     ...,
     [-0.0175,  0.0526,  0.1447,  ...,  0.5635,  0.6297,  0.7081],
     [ 0.2431,  0.4354,  0.4448,  ..., -0.4660, -0.5309, -0.4528],
     [-0.0063,  0.4734,  1.0196,  ..., -0.8193, -0.7774, -0.5409]],

    [[-0.4493, -0.8567, -0.3965,  ..., -0.1048, -0.3876, -0.5364],
     [-0.6204, -0.3849, -0.3530,  ...,  0.3625,  0.2440,  0.3402],
     [-0.6707, -0.2034, -1.2349,  ..., -0.3889, -0.4152, -0.4898],
     ...,
     [-0.4454, -0.3934, -0.3116,  ..., -0.3986,  0.4184,  0.7528],
     [ 0.9115,  0.4804,  0.5283,  ..., -0.3985, -0.6552,  0.0048],
     [ 0.8110,  1.1925,  0.2142,  ...,  0.7225,  0.5372,  0.9991]],

    [[ 0.1460, -0.3311,  0.5524,  ...,  0.1315,  0.3041,  0.1224],
     [ 0.0689,  0.1523,  0.6041,  ...,  0.7305,  0.8529,  0.5102],
     [-1.4563, -1.2934,  0.1552,  ..., -0.6928, -0.3413, -0.4378],
     ...,
     [-0.3893,  0.5175, -0.4975,  ...,  0.0043, -0.0401,  0.4407],
     [-0.3009, -0.3316,  0.1149,  ..., -0.8992, -0.8131, -0.2438],
     [-1.1128, -0.4745, -0.2984,  ..., -0.2784, -0.1001,  0.6662]],

    [[ 0.3559,  1.3747,  0.5877,  ...,  0.1049,  0.0235, -0.1569],
     [ 1.3827,  0.2189,  0.9957,  ...,  0.6291,  0.7082,  0.0795],
     [-1.4018, -1.3035, -1.3663,  ..., -0.8104, -0.7798, -1.0064],
     ...,
     [-0.8558, -0.5822,  0.0723,  ...,  0.1590,  0.0742,  0.6358],
     [-0.2536,  0.1003,  0.0701,  ..., -0.2895, -0.3217, -0.4538],
     [-1.5195, -1.1477, -0.3632,  ..., -1.6496, -1.6862, -0.8814]],

    [[-0.6123,  1.3670,  0.9122,  ...,  0.4825,  0.3046, -0.0215],
     [ 0.0415, -0.2664, -0.0817,  ...,  0.0427,  0.0627,  0.0251],
     [-0.3807,  0.8442,  0.6802,  ...,  0.5406,  0.3389,  0.2427],
     ...,
     [ 0.4262, -0.2382, -0.0265,  ..., -0.1674, -0.3640, -0.3335],
     [-0.3091, -0.3754,  0.0629,  ..., -0.1278, -0.0883, -0.4465],
     [ 0.6043, -0.5702, -0.2293,  ...,  0.5442,  0.6252,  0.7977]]],
   device='cuda:0', requires_grad=True), tensor([[ 4.7082e-02,  4.1346e-01, -3.3444e-01,  2.9570e-01],
    [-3.0883e-02, -3.0925e-01,  4.7123e-01,  2.4091e-01],
    [ 1.0045e-01, -3.5421e-01,  3.5039e-01, -2.9597e-01],
    [-1.0580e-01, -4.6909e-01,  1.3028e-01, -4.6993e-02],
    [ 1.9503e-02,  1.0519e-01,  4.6308e-01, -1.6372e-01],
    [-3.1330e-01,  8.5172e-03,  1.2206e-02,  1.2029e-01],
    [-5.9586e-02,  8.4758e-02,  4.4922e-01,  1.4817e-01],
    [ 2.5536e-02, -2.4294e-01,  4.1138e-01,  2.5648e-02],
    [ 3.8002e-01, -3.7613e-01, -5.5295e-02,  3.7601e-01],
    [-1.5758e-01,  4.5597e-01,  3.2043e-01,  3.6196e-01],
    [ 4.1267e-01,  4.3233e-01,  2.2660e-01, -4.9237e-01],
    [-1.5011e-01, -1.0473e-01, -2.0259e-01, -3.7253e-01],
    [-3.4709e-01,  4.8554e-01,  1.2161e-01, -4.7074e-01],
    [ 4.5883e-01,  1.7612e-01, -6.9445e-03,  1.3495e-01],
    [-6.3051e-02, -8.2720e-02, -5.1621e-02, -2.9443e-01],
    [ 3.1236e-01,  1.2674e-01,  1.9202e-01,  2.3875e-01],
    [ 4.4673e-01,  3.7752e-01,  7.8761e-02,  2.0817e-01],
    [-1.4266e-01, -8.3402e-02, -4.8220e-01, -4.7646e-03],
    [-1.0514e-01,  2.5823e-01,  3.6129e-01, -2.5278e-02],
    [ 3.9872e-01,  2.9057e-01, -1.4850e-01,  2.2747e-01],
    [ 7.8855e-02,  2.8368e-01,  1.8814e-01, -3.2391e-01],
    [-5.7115e-02,  4.8441e-01, -1.6526e-01, -1.8688e-01],
    [-3.6753e-01,  4.9411e-01,  1.7309e-01,  4.1119e-01],
    [ 4.0213e-01, -3.6803e-01,  2.7032e-01, -3.3465e-02],
    [-4.6479e-01,  7.8191e-02,  3.6651e-01, -4.3260e-01],
    [-2.4505e-02,  1.5999e-01, -2.0513e-01,  7.1506e-02],
    [ 4.8376e-01,  1.3882e-01, -2.2149e-01,  6.5740e-02],
    [ 4.8058e-01, -3.9497e-02, -1.0274e-01, -2.1423e-02],
    [-5.7657e-02, -1.4914e-01,  1.5792e-01,  4.1613e-01],
    [-4.2348e-01,  3.8299e-01, -4.4358e-04, -2.3955e-01],
    [ 1.2597e-01, -3.5201e-01,  2.6907e-01,  3.5920e-01],
    [-3.8491e-01, -6.8774e-02, -1.2165e-04, -3.2313e-01]], device='cuda:0',
   requires_grad=True), Parameter containing:

tensor([-0.4909, -0.4263, 0.3008, -0.1900, 0.4366, -0.2839, -0.1585, -0.3395, 0.0553, 0.1627, -0.0579, 0.4827, -0.1528, -0.1705, 0.4135, -0.0235, -0.4720, 0.2922, -0.3589, -0.3827, -0.2841, -0.3576, -0.2487, 0.3341, -0.3817, -0.4225, 0.3300, -0.2231, -0.0936, 0.4924, -0.2312, 0.1307], device='cuda:0', requires_grad=True), None, True

zmoka-zht commented 3 weeks ago

This looks like an error in causal_conv1d within the Mamba environment. Was the experimental environment installed correctly?
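
One way to check the environment: the traceback shows mamba_ssm calling causal_conv1d_fwd with five arguments, while the installed binding lists seven parameters, which suggests the mamba_ssm and causal_conv1d builds may be out of step with each other. The sketch below only reports what is installed so the versions can be compared against the repository's requirements. It assumes the packages were installed under the PyPI distribution names `torch`, `mamba-ssm`, and `causal-conv1d`; the helper name `installed_versions` is made up for this example.

```python
from importlib import metadata

# Assumed PyPI distribution names; adjust if your install uses different ones.
PACKAGES = ("torch", "mamba-ssm", "causal-conv1d")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If the reported versions differ from the ones the project was developed with, reinstalling a matching pair of mamba-ssm and causal-conv1d is a reasonable first thing to try.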