tensorlayer / SRGAN

Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
https://github.com/tensorlayer/tensorlayerx

Error in custom dataset yk #255

Open zhenzi0322 opened 1 year ago

zhenzi0322 commented 1 year ago

    Traceback (most recent call last):
      File "train.py", line 220, in <module>
        train()
      File "train.py", line 152, in train
        for step, (lr_patch, hr_patch) in enumerate(train_ds):
      File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 417, in __next__
        data = self._next_data()
      File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 438, in _next_data
        data = self._dataset_fetcher.fetch(index)
      File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 347, in fetch
        data = [self.dataset[id] for id in batch_indices]
      File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/dataflow/utils.py", line 347, in <listcomp>
        data = [self.dataset[id] for id in batch_indices]
      File "train.py", line 56, in __getitem__
        lr_patch = self.lr_trans(hr_patch)
      File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/vision/transforms/transforms.py", line 274, in __call__
        return resize(image, self.size, self.interpolation)
      File "/usr/local/lib/python3.8/dist-packages/tensorlayerx/vision/transforms/functional.py", line 202, in resize
        output = cv2.resize(image, dsize=(size[1], size[0]), interpolation=_cv2_interp_from_str[method])
    cv2.error: OpenCV(4.7.0) :-1: error: (-5:Bad argument) in function 'resize'
    > Overload resolution failed:
    >  - src data type = 17 is not supported
    >  - Expected Ptr<cv::UMat> for argument 'src'

AnsonCNS commented 9 months ago

You need to make sure your images are stored in a data type that cv2.resize() supports.

Converting all images to NumPy uint8 arrays solved the issue for me.
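The "src data type = 17" in the error is most likely a NumPy type number rather than an OpenCV one: in NumPy, type number 17 is the `object` dtype, meaning the array holds generic Python objects instead of numeric pixel values, and cv2 cannot convert it. A quick way to check this, as a sketch where `img` is a hypothetical stand-in for one entry of `train_hr_imgs`:

```python
import numpy as np

# Hypothetical stand-in for one loaded image: a 2x2 RGB image whose array
# ended up with dtype=object instead of a numeric dtype.
img = np.array([[[0, 128, 255], [10, 20, 30]],
                [[40, 50, 60], [70, 80, 90]]], dtype=object)

print(img.dtype, img.dtype.num)      # object 17 -> the type number cv2 rejects

# Casting to uint8 gives a dtype OpenCV accepts (NumPy type number 2).
fixed = img.astype(np.uint8)
print(fixed.dtype, fixed.dtype.num)  # uint8 2
```

Running the same `print(img.dtype, img.dtype.num)` on a real sample from your dataset shows whether it is the object dtype (or some other unsupported type) that reaches cv2.resize().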

In __getitem__(), add the conversion line like this:

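    # assumes numpy is imported as np at the top of train.py
    def __getitem__(self, index):
        img = self.train_hr_imgs[index]
        # cv2.resize() inside the transforms only accepts certain dtypes, so cast to uint8 first
        img = img.astype(np.uint8)  # add this
        hr_patch = self.hr_trans(img)
        lr_patch = self.lr_trans(hr_patch)
        return nor(lr_patch), nor(hr_patch)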
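A bare astype(np.uint8) assumes the pixel values are already numbers on the 0-255 scale: float values are truncated, and out-of-range values are not clipped. If a custom dataset might hold other ranges, a slightly more defensive conversion could look like the sketch below. `to_uint8` and its `scale` parameter are hypothetical names, not part of SRGAN or tensorlayerx, and the helper assumes each image is an array of scalar pixel values (not an object array of nested arrays).

```python
import numpy as np


def to_uint8(img, scale=1.0):
    """Hypothetical helper: return `img` as a uint8 array that cv2.resize() accepts.

    `scale` is an assumption about your data: pass 255.0 if pixel values are
    floats in [0, 1]; keep 1.0 if they are already on the 0-255 scale.
    """
    arr = np.asarray(img)
    if arr.dtype == np.uint8 and scale == 1.0:
        return arr  # already a supported dtype, return without copying
    # float64 handles object (type number 17), integer and float inputs alike.
    arr = arr.astype(np.float64) * scale
    # Round and clip so out-of-range values saturate instead of wrapping.
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)


if __name__ == "__main__":
    obj_img = np.array([[[300, -5, 127.6]]], dtype=object)  # 1x1 RGB, object dtype
    print(to_uint8(obj_img).tolist())                      # [[[255, 0, 128]]]
    float_img = np.array([[[0.0, 0.5, 1.0]]])              # 1x1 RGB in [0, 1]
    print(to_uint8(float_img, scale=255.0).tolist())       # [[[0, 128, 255]]]
```

In __getitem__() it would take the place of the astype line, e.g. `img = to_uint8(self.train_hr_imgs[index])`. Clipping trades a silent wrap-around for saturation at 0 and 255, which is usually the safer failure for image data.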