QinSY123 / 2024-MambaVC

Code for MambaVC: Learned Visual Compression with Selective State Spaces

Do I need to load a pre-trained model for the project to run? #5

Open shuodehaoa opened 1 week ago

shuodehaoa commented 1 week ago

The `--checkpoint` parameter in the training command seems to be required for specifying the path to a pre-trained model. Do I need a pre-trained model for the project to run?

QinSY123 commented 1 week ago

It is an optional argument; it is not required when starting training from scratch.
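For example, a fresh training run can simply omit `--checkpoint`. This is a hypothetical invocation: the flag spellings are inferred from the option names printed in the log below (`model`, `dataset`, `epochs`, `lmbda`, `batch_size`, ...), so check `python train.py --help` for the exact names, and substitute your own dataset path:

```shell
# Fresh run: no --checkpoint needed (flag names inferred from the log, not verified)
python train.py \
  --model bmshj2018-factorized \
  --dataset /path/to/your/dataset \
  --epochs 500 --lmbda 0.05 --batch_size 8

# Only pass --checkpoint later, to resume from or fine-tune a saved model:
# python train.py ... --checkpoint checkpoints/0.05/some_checkpoint.pth
```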

shuodehaoa commented 1 week ago

```
2024-06-25 21:50:50,403 [INFO ] Logging file is /nas/project/2024-MambaVC/checkpoints/0.05/20240625_215050.log
2024-06-25 21:50:50,404 [INFO ] model:bmshj2018-factorized
2024-06-25 21:50:50,404 [INFO ] dataset:/nas/dataset/QP22/Split/mambaVC
2024-06-25 21:50:50,404 [INFO ] epochs:500
2024-06-25 21:50:50,404 [INFO ] learning_rate:0.0001
2024-06-25 21:50:50,404 [INFO ] num_workers:128
2024-06-25 21:50:50,404 [INFO ] lmbda:0.05
2024-06-25 21:50:50,404 [INFO ] batch_size:8
2024-06-25 21:50:50,404 [INFO ] test_batch_size:1
2024-06-25 21:50:50,404 [INFO ] aux_learning_rate:0.001
2024-06-25 21:50:50,404 [INFO ] patch_size:(256, 256)
2024-06-25 21:50:50,404 [INFO ] cuda:True
2024-06-25 21:50:50,404 [INFO ] save:True
2024-06-25 21:50:50,404 [INFO ] seed:42
2024-06-25 21:50:50,404 [INFO ] clip_max_norm:1.0
2024-06-25 21:50:50,404 [INFO ] checkpoint:None
2024-06-25 21:50:50,404 [INFO ] type:mse
2024-06-25 21:50:50,404 [INFO ] save_path:/nas/project/2024-MambaVC/checkpoints
2024-06-25 21:50:50,404 [INFO ] skip_epoch:0
2024-06-25 21:50:50,404 [INFO ] N:128
2024-06-25 21:50:50,404 [INFO ] lr_epoch:[450, 490]
2024-06-25 21:50:50,404 [INFO ] continue_train:False
cuda
/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 128 worker processes in total. Our suggested max number of worker in current system is 28, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
milestones: [450, 490]
  0%|          | 0/500 [00:00<?, ?it/s]
2024-06-25 21:50:51,487 [INFO ] ======Current epoch 0 ======
2024-06-25 21:50:51,487 [INFO ] Learning rate: 0.0001
0it [00:00, ?it/s]
  0%|          | 0/500 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 466, in <module>
    main(sys.argv[1:])
  File "train.py", line 431, in main
    train_one_epoch(
  File "train.py", line 147, in train_one_epoch
    out_net = model(d)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 1300, in forward
    y = self.g_a(x)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 1151, in forward
    x = self._forward(x).permute(0, 3, 1, 2)
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 1143, in _forward
    x = input + self.drop_path(self.op(self.norm(input)))
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 1021, in forward
    y = self.forward_core(x)
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 930, in forward_corev2
    return cross_selective_scan(
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 480, in cross_selective_scan
    ys: torch.Tensor = selective_scan(
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 446, in selective_scan
    return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/gao/anaconda3/envs/mambaVC/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/nas/project/2024-MambaVC/models/MambaVC.py", line 521, in forward
    out, x, rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
TypeError: fwd(): incompatible function arguments. The following argument types are supported:

  1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: bool, arg8: int, arg9: bool) -> List[torch.Tensor]

Invoked with: tensor([[[ 2.5545e-01,  3.2468e-01,  3.1105e-01,  ...,  1.1445e-01,
        1.1466e-01,  1.1707e-01],
      [ 2.2994e-01,  2.8560e-01,  2.1514e-01,  ...,  2.9078e-01,
        3.4029e-01,  3.0308e-01],
      [ 7.6905e-02,  8.9864e-02,  1.0398e-01,  ..., -6.8421e-02,
       -7.0336e-02, -3.3716e-02],
      ...,
      [ 2.7783e-02,  6.0757e-02,  1.5139e-01,  ...,  7.5845e-02,
        1.5906e-02,  9.4314e-02],
      [-1.2216e-02,  7.2406e-02,  1.3762e-01,  ...,  2.0739e-02,
        1.8137e-02,  1.6718e-02],
      [ 6.8043e-02,  8.0412e-02,  5.6260e-02,  ...,  3.7458e-02,
        2.4173e-02,  4.2964e-02]],

    [[ 2.6391e-01,  3.4885e-01,  3.1744e-01,  ...,  9.8407e-02,
       1.0396e-01,  1.1010e-01],
     [ 2.2787e-01,  2.8049e-01,  2.5370e-01,  ...,  2.9538e-01,
       3.1826e-01,  3.1299e-01],
     [ 6.8315e-02,  7.2892e-02,  1.1020e-01,  ..., -7.1713e-02,
      -8.1271e-02, -4.5279e-02],
     ...,
     [ 2.6714e-02,  6.7872e-02,  1.5585e-01,  ...,  8.8021e-02,
       8.6730e-03,  7.6960e-02],
     [-2.5178e-02,  6.1933e-02,  9.9606e-02,  ..., -1.2581e-02,
      -3.6286e-03,  6.7224e-03],
     [ 6.7677e-02,  8.8382e-02,  6.3443e-02,  ...,  1.5692e-02,
      -4.9952e-03,  3.0603e-02]],

    [[ 2.4910e-01,  3.2694e-01,  2.9670e-01,  ...,  1.1882e-01,
       1.1018e-01,  1.1878e-01],
     [ 2.1325e-01,  2.7530e-01,  2.4434e-01,  ...,  2.9738e-01,
       3.5416e-01,  3.1198e-01],
     [ 6.0548e-02,  7.1420e-02,  1.0630e-01,  ..., -7.2580e-02,
      -7.3519e-02, -3.7565e-02],
     ...,
     [ 2.1059e-02,  7.0689e-02,  1.6398e-01,  ...,  7.0748e-02,
       1.8903e-02,  9.4413e-02],
     [-1.9248e-02,  8.2502e-02,  1.4179e-01,  ...,  2.1170e-02,
       5.5312e-02,  1.8289e-02],
     [ 6.7050e-02,  7.8700e-02,  5.8725e-02,  ...,  3.4673e-02,
       2.3773e-02,  3.2414e-02]],
```

Can you give any advice on this error?
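For context: the traceback shows the wrapper passing nine positional arguments (`u, delta, A, B, C, D, delta_bias, delta_softplus, 1`) to `selective_scan_cuda_oflex.fwd`, while the only overload the installed binding supports takes ten (`arg0`..`arg9`), which suggests the compiled extension and the repo's Python code come from different versions. The failing kernel is a fused implementation of the selective-scan recurrence; a minimal non-fused PyTorch sketch of that recurrence (shapes are my assumptions for illustration, not the repo's exact kernel interface) looks like:

```python
import torch

def selective_scan_ref(u, delta, A, B, C, D):
    """Reference (non-fused) selective scan: a sketch of the recurrence the
    fused CUDA kernel computes. Assumed shapes (for illustration only):
      u, delta: (batch, dim, length); A: (dim, state);
      B, C: (batch, state, length); D: (dim,)
    """
    batch, dim, length = u.shape
    state = A.shape[1]
    x = u.new_zeros(batch, dim, state)  # hidden SSM state
    ys = []
    for t in range(length):
        dt = delta[:, :, t].unsqueeze(-1)                    # (batch, dim, 1)
        # Discretized state update: x_t = exp(dt*A) * x_{t-1} + dt * B_t * u_t
        x = torch.exp(dt * A) * x + dt * B[:, :, t].unsqueeze(1) * u[:, :, t].unsqueeze(-1)
        # Readout: y_t = C_t . x_t
        ys.append((x * C[:, :, t].unsqueeze(1)).sum(-1))     # (batch, dim)
    # Skip connection through D
    return torch.stack(ys, dim=-1) + D.view(1, -1, 1) * u   # (batch, dim, length)
```

This is only meant to show what the kernel computes, not to replace it; the fused CUDA version exists precisely because this sequential loop is slow.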