lct1997 / BasicSTISR

BasicSTISR (Basic Scene Text Image Super Resolution) is an open-source scene text image super-resolution toolbox based on PyTorch.
Apache License 2.0

TypeError: cannot pickle 'Environment' object #1

Closed 2061360308 closed 1 month ago

2061360308 commented 1 month ago

Environment

  • Windows 11
  • Python 3.10.9
  • torch 2.3.1

Problem description

I encountered the following problem:

Traceback (most recent call last):
  File "...\BasicSTISR\main.py", line 45, in <module>
    main(config, args)
  File "...\BasicSTISR\main.py", line 14, in main
    Mission.train()
  File "...\BasicSTISR\interfaces\super_resolution.py", line 144, in train
    for j, data in (enumerate(train_loader)):
  File "...\.venv\lib\site-packages\torch\utils\data\dataloader.py", line 439, in __iter__
    return self._get_iterator()
  File "...\.venv\lib\site-packages\torch\utils\data\dataloader.py", line 387, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "...\.venv\lib\site-packages\torch\utils\data\dataloader.py", line 1040, in __init__
    w.start()
  File "D:\program\python\Python3.10.9\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "D:\program\python\Python3.10.9\lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "D:\program\python\Python3.10.9\lib\multiprocessing\context.py", line 336, in _Popen
    return Popen(process_obj)
  File "D:\program\python\Python3.10.9\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "D:\program\python\Python3.10.9\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle 'Environment' object

With the help of Copilot, which suggested that I open the LMDB environment in the __getitem__ method instead of the __init__ method, I made some modifications to the code. (The complete code is at the end.)

I want to know why I encountered this problem and what negative impact my workaround may have. 😊

# ./dataset/dataset.py  lmdbDataset_real

class lmdbDataset_real(Dataset):
    def __init__(
            self, root=None,
            voc_type='upper',
            max_len=100,
            test=False,
            cutblur=False,
            manmade_degrade=False,
            rotate=None
    ):
        super(lmdbDataset_real, self).__init__()
        self.root = root
        self.cb_flag = cutblur
        self.rotate = rotate
        self.voc_type = voc_type
        self.max_len = max_len
        self.test = test
        self.manmade_degrade = manmade_degrade

    def __len__(self):
        with lmdb.open(self.root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) as env:
            with env.begin(write=False) as txn:
                nSamples = int(txn.get(b'num-samples'))
        return nSamples

    # ... keep other methods as they are ...

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1
        with lmdb.open(self.root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) as env:
            txn = env.begin(write=False)
            # ... rest of your code ...

Complete code

# ./dataset/dataset.py  lmdbDataset_real

class lmdbDataset_real(Dataset):
    def __init__(
            self, root=None,
            voc_type='upper',
            max_len=100,
            test=False,
            cutblur=False,
            manmade_degrade=False,
            rotate=None
    ):
        super(lmdbDataset_real, self).__init__()
        self.root = root
        self.cb_flag = cutblur
        self.rotate = rotate
        self.voc_type = voc_type
        self.max_len = max_len
        self.test = test
        self.manmade_degrade = manmade_degrade

    def __len__(self):
        with lmdb.open(self.root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) as env:
            with env.begin(write=False) as txn:
                nSamples = int(txn.get(b'num-samples'))
        return nSamples

    def rotate_img(self, image, angle):
        if not angle == 0.0:
            image = np.array(image)
            (h, w) = image.shape[:2]
            scale = 1.0
            # set the rotation center
            center = (w / 2, h / 2)
            # anti-clockwise angle in the function
            M = cv2.getRotationMatrix2D(center, angle, scale)
            image = cv2.warpAffine(image, M, (w, h))
            # back to PIL image
            image = Image.fromarray(image)

        return image

    def cutblur(self, img_hr, img_lr):
        p = random.random()

        img_hr_np = np.array(img_hr)
        img_lr_np = np.array(img_lr)

        randx = int(img_hr_np.shape[1] * (0.2 + 0.8 * random.random()))

        if p > 0.7:
            left_mix = random.random()
            if left_mix <= 0.5:
                img_lr_np[:, randx:] = img_hr_np[:, randx:]
            else:
                img_lr_np[:, :randx] = img_hr_np[:, :randx]

        return Image.fromarray(img_lr_np)

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1
        with lmdb.open(self.root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) as env:
            with env.begin(write=False) as txn:
                label_key = b'label-%09d' % index
                word = ""  # str(txn.get(label_key).decode())
                img_HR_key = b'image_hr-%09d' % index  # 128*32
                img_lr_key = b'image_lr-%09d' % index  # 64*16
                try:
                    img_HR = buf2PIL(txn, img_HR_key, 'RGB')
                    if self.manmade_degrade:
                        img_lr = degradation(img_HR)
                    else:
                        img_lr = buf2PIL(txn, img_lr_key, 'RGB')
                    # print("GOGOOGO..............", img_HR.size)
                    if self.cb_flag and not self.test:
                        img_lr = self.cutblur(img_HR, img_lr)

                    if not self.rotate is None:

                        if not self.test:
                            angle = random.random() * self.rotate * 2 - self.rotate
                        else:
                            angle = 0  # self.rotate

                        # img_HR = self.rotate_img(img_HR, angle)
                        # img_lr = self.rotate_img(img_lr, angle)

                    img_lr_np = np.array(img_lr).astype(np.uint8)
                    img_lry = cv2.cvtColor(img_lr_np, cv2.COLOR_RGB2YUV)
                    img_lry = Image.fromarray(img_lry)

                    img_HR_np = np.array(img_HR).astype(np.uint8)
                    img_HRy = cv2.cvtColor(img_HR_np, cv2.COLOR_RGB2YUV)
                    img_HRy = Image.fromarray(img_HRy)
                    word = txn.get(label_key)
                    if word is None:
                        print("None word:", label_key)
                        word = " "
                    else:
                        word = str(word.decode())
                    # print("img_HR:", img_HR.size, img_lr.size())

                except IOError:
                    # unreadable sample: fall back to the next one
                    return self[index + 1]
                if len(word) > self.max_len:
                    return self[index + 1]
                label_str = str_filt(word, self.voc_type)
                return img_HR, img_lr, img_HRy, img_lry, label_str
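
Opening the LMDB environment inside __len__ and __getitem__ as above does keep the handle out of the pickled dataset, but it also re-opens the database for every single sample, which adds per-item overhead. A common alternative is to open the environment lazily on first access and drop it from the pickled state, so each worker process opens it exactly once. A minimal sketch of that idea (not the BasicSTISR code; the class name is hypothetical and the per-sample decoding is elided):

import lmdb
from torch.utils.data import Dataset


class LazyLmdbDataset(Dataset):
    """Sketch of per-worker LMDB opening (hypothetical class, not lmdbDataset_real)."""

    def __init__(self, root):
        self.root = root
        self.env = None  # opened lazily, so there is no handle to pickle
        # Read the sample count once with a short-lived environment that is
        # closed right away instead of being stored on the dataset.
        with lmdb.open(root, max_readers=1, readonly=True, lock=False,
                       readahead=False, meminit=False) as env:
            with env.begin(write=False) as txn:
                self.n_samples = int(txn.get(b'num-samples'))

    def __getstate__(self):
        # If an environment was already opened in this process, drop it from
        # the pickled state; each spawned worker re-opens its own handle.
        state = self.__dict__.copy()
        state['env'] = None
        return state

    def _ensure_env(self):
        if self.env is None:
            self.env = lmdb.open(self.root, max_readers=1, readonly=True,
                                 lock=False, readahead=False, meminit=False)
        return self.env

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        with self._ensure_env().begin(write=False) as txn:
            # ... decode the HR/LR images and the label here, exactly as in
            # lmdbDataset_real above ...
            return txn.get(b'label-%09d' % (index + 1))

The __getstate__ hook matters because the environment may already have been opened in the main process (for example by an earlier __getitem__ call); dropping it from the pickled state lets each worker simply re-open its own handle instead of triggering the same pickling error.
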
lct1997 commented 1 month ago

If your environment is Windows 11, please set the workers to 0 instead of 8; for Linux users it is the opposite (keep the multiple workers). You can find this parameter in config/super_resolution.yaml (line 15).
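
For context on why this helps: with workers > 0 on Windows, the DataLoader spawns its worker processes and has to pickle the dataset object to send it to each of them, and an open lmdb.Environment handle cannot be pickled, which is exactly the error in the traceback above. With workers set to 0, all loading happens in the main process and nothing is pickled. A minimal, self-contained sketch of the effect (not the BasicSTISR training code; the toy dataset, batch size, and worker counts are placeholders):

import platform
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):  # stand-in for lmdbDataset_real
    def __len__(self):
        return 4

    def __getitem__(self, index):
        return index


if __name__ == "__main__":
    # Per the advice above: 0 workers on Windows, multiple workers elsewhere.
    num_workers = 0 if platform.system() == "Windows" else 8

    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=num_workers)
    for batch in loader:
        print(batch)

With num_workers=0 this loop behaves the same on Windows and Linux; with num_workers > 0 on Windows, a dataset that stores an un-picklable handle (such as an already-open LMDB environment) would reproduce the TypeError above.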