Algolzw / BSRT

PyTorch code for "BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment", CVPRW 2022; 1st place in the NTIRE 2022 BurstSR Challenge (real-world track).

Training without access to Zurich RAW to RGB (ZRR) validation #4

Closed: ConnorBaker closed this issue 1 year ago

ConnorBaker commented 1 year ago

Hello,

Thank you for sharing your code!

I have a fork (https://github.com/ConnorBaker/BSRT) where I updated CUDA/PyTorch and switched to the DCN implementation offered by torchvision (details here: https://pytorch.org/vision/stable/generated/torchvision.ops.deform_conv2d.html#torchvision.ops.deform_conv2d). However, torchvision's DCN expects the second dimension of the weight tensor to be divided by the number of groups, so loading the pre-trained model fails with a tensor-size mismatch. I tried to work around this so I could still use your pre-trained model, but was unable to: see https://github.com/ConnorBaker/BSRT/commit/7d9c5bdb873693109fab2a8217e482f0cf5c6ee9 for details and https://github.com/ConnorBaker/BSRT/commit/8b8339491205bdbddf14ffdea01cc92584e24939 for the full switch-over.
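
For context, torchvision infers the number of weight groups from the weight tensor itself: `deform_conv2d` expects a weight of shape `(out_channels, in_channels // groups, kH, kW)` and computes the group count as `in_channels // weight.shape[1]`, so a checkpoint that stores the full `(out_channels, in_channels, kH, kW)` weight for a grouped layer trips the shape check. A minimal sketch of the torchvision call (the shapes here are illustrative, not BSRT's):

```python
import torch
from torchvision.ops import deform_conv2d

x = torch.randn(1, 64, 32, 32)
kh = kw = 3
# offset has 2 * offset_groups * kh * kw channels; one offset group here
offset = torch.randn(1, 2 * kh * kw, 32, 32)
# weight's second dim is in_channels // groups, so this weight implies groups=1
weight = torch.randn(64, 64 // 1, kh, kw)
out = deform_conv2d(x, offset, weight, padding=1)
print(out.shape)  # torch.Size([1, 64, 32, 32])
```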

How can I train this model? I tried to get access to the full Zurich RAW to RGB (ZRR) dataset, but it is no longer available. Without the validation data (val directory from ZRR), I get errors when I try to run training (see example below).

Any ideas?

Many thanks, Connor


Example training run:

$ python main.py --n_GPUs 1 --print_every 40 --lr 0.0001 --decay 150-300 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256 --root ~/datasets/zurich-raw-to-rgb --models_root ~/models
train data: 46839, test data: 300
Test only: False
Making model:  BSRT
Patch size:  256
depths:  [6, 6, 6, 6, 6]
using swinfeature
/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/functional.py:478: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484801627/work/aten/src/ATen/native/TensorShape.cpp:2894.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Preparing loss function:
1.000 * L1
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: /home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth
number of parameters:  4786641
Fix keys: ['spynet', 'dcnpack'] for the first 5 epochs.
[1280/46839]    [0.0483]    58.5+2.1s
[2560/46839]    [0.0382]    54.9+0.2s
[3840/46839]    [0.0277]    54.9+0.2s
[5120/46839]    [0.0223]    54.9+0.2s
[6400/46839]    [0.0277]    54.8+0.2s
[7680/46839]    [0.0252]    54.9+0.2s
[8960/46839]    [0.0292]    54.8+0.2s
[10240/46839]   [0.0321]    54.9+0.2s
[11520/46839]   [0.0256]    54.9+0.2s
[12800/46839]   [0.0240]    54.9+0.2s
[14080/46839]   [0.0242]    54.9+0.2s
[15360/46839]   [0.0240]    54.9+0.2s
[16640/46839]   [0.0207]    54.9+0.2s
[17920/46839]   [0.0184]    54.9+0.2s
[19200/46839]   [0.0238]    55.0+0.2s
[20480/46839]   [0.0200]    54.9+0.2s
[21760/46839]   [0.0150]    54.8+0.2s
[23040/46839]   [0.0195]    54.8+0.2s
[24320/46839]   [0.0154]    54.8+0.2s
[25600/46839]   [0.0215]    54.8+0.2s
[26880/46839]   [0.0190]    54.9+0.2s
[28160/46839]   [0.0167]    54.8+0.2s
[29440/46839]   [0.0194]    54.8+0.2s
[30720/46839]   [0.0160]    54.8+0.2s
[32000/46839]   [0.0194]    54.8+0.2s
[33280/46839]   [0.0182]    54.8+0.2s
[34560/46839]   [0.0153]    54.8+0.2s
[35840/46839]   [0.0179]    54.8+0.2s
[37120/46839]   [0.0199]    54.9+0.2s
[38400/46839]   [0.0160]    54.8+0.2s
[39680/46839]   [0.0186]    54.9+0.2s
[40960/46839]   [0.0195]    54.9+0.2s
[42240/46839]   [0.0169]    54.9+0.2s
[43520/46839]   [0.0156]    54.9+0.2s
[44800/46839]   [0.0172]    54.8+0.2s
[46080/46839]   [0.0158]    54.9+0.2s
Epoch 1 cost time: 2018.3s, lr: 0.000100
save model...
save model...
Testing...
[ WARN:0@2022.657] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0000/im_raw_00.png'): can't open/read file: check file path/integrity
[ WARN:0@2022.657] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0032/im_raw_00.png'): can't open/read file: check file path/integrity
[ WARN:0@2022.658] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0064/im_raw_00.png'): can't open/read file: check file path/integrity
[ WARN:0@2022.659] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0160/im_raw_00.png'): can't open/read file: check file path/integrity
[ WARN:0@2022.659] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0128/im_raw_00.png'): can't open/read file: check file path/integrity
[ WARN:0@2022.659] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0192/im_raw_00.png'): can't open/read file: check file path/integrity
Traceback (most recent call last):
  File "main.py", line 98, in <module>
[ WARN:0@2022.660] global /home/conda/feedstock_root/build_artifacts/libopencv_1661642960203/work/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('/home/connorbaker/datasets/zurich-raw-to-rgb/val/bursts/0256/im_raw_00.png'): can't open/read file: check file path/integrity
    main()
  File "main.py", line 42, in main
    main_worker(0, args.n_GPUs, args)
  File "main.py", line 85, in main_worker
    t.train()
  File "/home/connorbaker/BSRT/code/synthetic/bsrt/trainer.py", line 228, in train
    self.test()
  File "/home/connorbaker/BSRT/code/synthetic/bsrt/trainer.py", line 263, in test
    for i, batch_value in enumerate(self.loader_valid):
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
    return self._process_data(data)
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
    data.reraise()
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/connorbaker/micromamba/envs/bsrt-real/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/connorbaker/BSRT/code/synthetic/bsrt/datasets/synthetic_burst_val_set.py", line 61, in __getitem__
    burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
  File "/home/connorbaker/BSRT/code/synthetic/bsrt/datasets/synthetic_burst_val_set.py", line 61, in <listcomp>
    burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
  File "/home/connorbaker/BSRT/code/synthetic/bsrt/datasets/synthetic_burst_val_set.py", line 33, in _read_burst_image
    im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)
AttributeError: 'NoneType' object has no attribute 'astype'
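
The final `AttributeError` is just `cv2.imread` returning `None` for the missing validation files rather than raising. A hypothetical guard in the loader (not the repository's actual code; the `IMREAD_UNCHANGED` flag is an assumption) would make the failure explicit:

```python
import cv2
import numpy as np
import torch

def read_burst_image(path: str) -> torch.Tensor:
    """Read one 14-bit RAW burst frame stored as PNG; fail loudly if missing."""
    im = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # the read flag is an assumption
    if im is None:  # cv2.imread returns None on a bad path instead of raising
        raise FileNotFoundError(f"Could not read burst image: {path}")
    # same conversion as synthetic_burst_val_set.py's _read_burst_image
    return torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1) / (2 ** 14)
```
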
Algolzw commented 1 year ago

Hi, the standard validation dataset for synthetic training can be downloaded here.

You can also manually split a validation set from the training data; a sketch of one way to do this is below. But in that case, you must still evaluate the trained model on the standard validation dataset afterwards.
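
A minimal sketch of such a split with `torch.utils.data.random_split` (the dummy `TensorDataset` just stands in for the real burst dataset; none of these names come from the BSRT code):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the real ZRR training set (46839 crops in the log above).
full_train = TensorDataset(torch.arange(46839).unsqueeze(1))

n_val = 300  # match the size of the standard validation split
n_train = len(full_train) - n_val
train_subset, val_subset = random_split(
    full_train,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(0),  # reproducible split
)
print(len(train_subset), len(val_subset))  # 46539 300
```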

ConnorBaker commented 1 year ago

Thank you!

May I email you with some questions about training on larger images / using the synthetic burst pipelines? I understand if you would rather not!

Thank you again for sharing your work :)

Algolzw commented 1 year ago

No problem. You can email me if you have any questions.