open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0
5.92k stars 1.65k forks source link

[Feature] Proposal: Refactor image loading transforms for better modularity and extensibility #3198

Open shenshanf opened 3 weeks ago

shenshanf commented 3 weeks ago

What is the feature?

Current Issues:

  1. The LoadImageFromFile class has hardcoded key mappings, making it difficult to extend for different use cases like stereo or multi-view scenarios.
  2. The core image loading logic is duplicated when creating new loading transforms.
  3. The current inheritance-based approach for creating new loading transforms (like stereo) is not flexible enough for dynamic scenarios (e.g., varying number of views in MVS).

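Proposed Solution: Extract the core image-loading logic into a separate `ImageLoader` class and refactor the transforms to use this shared functionality. Here is a basic implementation proposal. The snippets assume the imports already used by mmcv's loading transforms: `Optional`, `numpy as np`, `fileio`, `mmcv`, `BaseTransform` and `TRANSFORMS`.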


class ImageLoader:
    """Read a single image file and decode it into an array.

    Shared by the image-loading transforms so that file reading, decoding,
    dtype conversion and error handling live in one place.

    Args:
        to_float32 (bool): If True, cast the decoded image to
            ``np.float32``; otherwise keep the decoder's dtype.
            Defaults to False.
        color_type (str): Passed as ``flag`` to ``mmcv.imfrombytes``.
            Defaults to 'color'.
        imdecode_backend (str): Passed as ``backend`` to
            ``mmcv.imfrombytes``. Defaults to 'cv2'.
        ignore_empty (bool): If True, ``load`` returns None instead of
            raising when the file cannot be read or decoded.
            Defaults to False.
        backend_args (dict, optional): Passed as ``backend_args`` to
            ``fileio.get``. Defaults to None.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend
        self.ignore_empty = ignore_empty
        self.backend_args = backend_args

    def load(self, filepath: str) -> Optional[np.ndarray]:
        """Read ``filepath`` and decode it into an image array.

        Args:
            filepath (str): Path or URI understood by ``fileio.get``.

        Returns:
            np.ndarray or None: The decoded image (cast to ``np.float32``
            when ``to_float32`` is set), or None if reading/decoding failed
            and ``ignore_empty`` is True.

        Raises:
            Exception: Whatever ``fileio.get`` or ``mmcv.imfrombytes``
                raises, re-raised unchanged when ``ignore_empty`` is False.
            ValueError: If the decoder returned None and ``ignore_empty``
                is False.
        """
        # Keep the try narrow: only I/O and decoding failures count as an
        # "empty" image; bugs in post-processing must not be swallowed.
        try:
            img_bytes = fileio.get(filepath, backend_args=self.backend_args)
            img = mmcv.imfrombytes(
                img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        except Exception:
            if self.ignore_empty:
                return None
            raise  # bare raise keeps the original traceback

        # A decoder backend may signal undecodable data by returning None
        # rather than raising; treat that as a load failure too instead of
        # silently handing None back to the caller.
        if img is None:
            if self.ignore_empty:
                return None
            raise ValueError(f'failed to decode image: {filepath}')

        if self.to_float32:
            img = img.astype(np.float32)
        return img

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'to_float32={self.to_float32}, '
                f"color_type='{self.color_type}', "
                f"imdecode_backend='{self.imdecode_backend}', "
                f'ignore_empty={self.ignore_empty}, '
                f'backend_args={self.backend_args})')

Any other context?

_No response_
shenshanf commented 3 weeks ago
@TRANSFORMS.register_module()
class LoadImageFromFile(BaseTransform):
    """Load a single image from file.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to cast the loaded image to
            ``np.float32``. Defaults to False.
        color_type (str): Passed as ``flag`` to ``mmcv.imfrombytes``.
            Defaults to 'color'.
        imdecode_backend (str): Passed as ``backend`` to
            ``mmcv.imfrombytes``. Defaults to 'cv2'.
        ignore_empty (bool): Whether the loader returns None instead of
            raising when the image cannot be read. Defaults to False.
        backend_args (dict, optional): Passed as ``backend_args`` to
            ``fileio.get``. Defaults to None.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        super().__init__()
        self.loader = ImageLoader(
            to_float32=to_float32,
            color_type=color_type,
            imdecode_backend=imdecode_backend,
            ignore_empty=ignore_empty,
            backend_args=backend_args
        )

    def transform(self, results: dict) -> Optional[dict]:
        """Load the image referenced by ``results['img_path']``.

        Args:
            results (dict): Result dict; must contain ``'img_path'``.

        Returns:
            dict or None: ``results`` updated in place with ``img`` and with
            ``img_shape``/``ori_shape`` (the first two dims of
            ``img.shape``), or None when the loader returns None (e.g. a
            failed read with ``ignore_empty=True``).
        """
        img = self.loader.load(results['img_path'])
        if img is None:
            return None

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self) -> str:
        # Same field order and format as mmcv's pre-refactor
        # LoadImageFromFile.__repr__, so printed pipelines stay comparable.
        loader = self.loader
        return (f'{self.__class__.__name__}('
                f'ignore_empty={loader.ignore_empty}, '
                f'to_float32={loader.to_float32}, '
                f"color_type='{loader.color_type}', "
                f"imdecode_backend='{loader.imdecode_backend}', "
                f'backend_args={loader.backend_args})')

@TRANSFORMS.register_module()
class LoadMultiViewImage(BaseTransform):
    """Load every image listed in ``results['img_paths']``.

    Required Keys:

    - img_paths

    Added Keys:

    - imgs
    - img_shapes
    - ori_shapes

    Args:
        to_float32 (bool): Whether to cast each loaded image to
            ``np.float32``. Defaults to False.
        color_type (str): Decoding flag for every view. Defaults to 'color'.
        imdecode_backend (str): Decoding backend for every view.
            Defaults to 'cv2'.
        ignore_empty (bool): Whether the loader returns None instead of
            raising when a view cannot be read. Defaults to False.
        backend_args (dict, optional): File-backend arguments.
            Defaults to None.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        super().__init__()
        loader_cfg = dict(
            to_float32=to_float32,
            color_type=color_type,
            imdecode_backend=imdecode_backend,
            ignore_empty=ignore_empty,
            backend_args=backend_args)
        self.loader = ImageLoader(**loader_cfg)

    def transform(self, results: dict) -> Optional[dict]:
        """Load all views in order.

        Returns:
            dict or None: ``results`` with ``imgs`` (one array per path, in
            input order) plus ``img_shapes``/``ori_shapes`` (first two dims
            of each image), or None as soon as any view loads as None.
        """
        views, view_shapes = [], []
        for path in results['img_paths']:
            view = self.loader.load(path)
            # Stop at the first failed view; later paths are never read.
            if view is None:
                return None
            views.append(view)
            view_shapes.append(view.shape[:2])

        results['imgs'] = views
        results['img_shapes'] = view_shapes
        # A separate list so later per-view shape edits don't alias.
        results['ori_shapes'] = list(view_shapes)
        return results
shenshanf commented 3 weeks ago
@TRANSFORMS.register_module()
class LoadStereoImage(BaseTransform):
    """Load the left and right images of a stereo pair.

    Required Keys:

    - left_img_path
    - right_img_path

    Added Keys:

    - left_img
    - right_img
    - left_img_shape
    - right_img_shape
    - ori_left_shape
    - ori_right_shape

    Args:
        to_float32 (bool): Whether to cast both images to ``np.float32``.
            Defaults to False.
        color_type (str): Decoding flag for both images. Defaults to 'color'.
        imdecode_backend (str): Decoding backend for both images.
            Defaults to 'cv2'.
        ignore_empty (bool): Whether the loader returns None instead of
            raising when an image cannot be read. Defaults to False.
        backend_args (dict, optional): File-backend arguments.
            Defaults to None.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        super().__init__()
        loader_cfg = dict(
            to_float32=to_float32,
            color_type=color_type,
            imdecode_backend=imdecode_backend,
            ignore_empty=ignore_empty,
            backend_args=backend_args)
        self.loader = ImageLoader(**loader_cfg)

    def transform(self, results: dict) -> Optional[dict]:
        """Load both stereo views.

        Returns:
            dict or None: ``results`` with both images and their shapes
            (first two dims of each image), or None if either view loads
            as None.
        """
        # The left view is loaded first; if it fails the right one is
        # never read.
        left = self.loader.load(results['left_img_path'])
        if left is None:
            return None

        right = self.loader.load(results['right_img_path'])
        if right is None:
            return None

        # Both the current and the ``ori_*`` shapes start out as the
        # loaded image's shape.
        results['left_img'] = left
        results['right_img'] = right
        results['left_img_shape'] = left.shape[:2]
        results['right_img_shape'] = right.shape[:2]
        results['ori_left_shape'] = left.shape[:2]
        results['ori_right_shape'] = right.shape[:2]
        return results