pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Using Xarray Dataset for transform operations similar to PyTorch Datasets #9646

Closed ArkashJ closed 1 month ago

ArkashJ commented 1 month ago

What is your issue?

In PyTorch datasets, we can apply transform operations in the `__getitem__` method given an index. Is there a way to do the same in xarray, or is it recommended that we convert the xarray dataset into batches using xbatcher, wrap those in a PyTorch DataLoader, and work on the data loader?
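
For reference, the PyTorch pattern I mean looks roughly like this (a minimal sketch with made-up names):

```python
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, samples, transform=None):
        self.samples = samples      # e.g. a list of arrays or file paths
        self.transform = transform  # callable applied per item, on access

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        if self.transform:
            item = self.transform(item)  # runs only when the loader requests idx
        return item
```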

welcome[bot] commented 1 month ago

Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!

max-sixty commented 1 month ago

Could you give an example?

ArkashJ commented 1 month ago

Sure. I'm working on a Dataset class that will be fed into a 3D U-Net, and my data is stored in a Zarr file. The class does some computation using the offsets and stores the data in an xarray Dataset. I need to normalize the raw data and binarize the labels based on a threshold. If this were a PyTorch dataset, I'd have a `__getitem__` method taking an idx to which I could apply a preprocessing transform. Here, however, I'm not sure what the best way to apply the transform would be, because my data can reach 30+ GB. Would `apply_ufunc` be useful here? Here's my current class (a lazy sketch of the transform I have in mind follows it):


```python
from typing import Tuple, Dict
from loguru import logger
from pathlib import Path
import zarr
import numpy as np
import json
import xarray as xr
from incasem_v2.utils.dtypes import verify_zattrs

class Cell3DDataset:
    def __init__(
        self,
        path_to_zarr: Path,
    ):
        self.path_to_zarr = path_to_zarr

        self.zarr_store = (
            zarr.open(str(self.path_to_zarr), mode="r")
            if self._verify_zarr(self.path_to_zarr)
            else None
        )
        self.label_offsets: Dict[str, Tuple[int, int, int]] = dict(
            sorted(self._get_labels_offset().items())
        )
        self.raw_offsets: Dict[str, Tuple[int, int, int]] = dict(
            sorted(self._get_raw_offset().items())
        )
        self.metric_mask_offsets: Dict[str, Tuple[int, int, int]] = dict(
            sorted(self._get_metric_mask_offset().items())
        )
        self.dataset = self._create_dataset()

    def _create_dataset(self) -> xr.Dataset: 
        ds = xr.Dataset()
        return ds

    @staticmethod
    def _get_offset(
        path_to_zattrs: Path,
    ) -> Tuple[int, int, int]:
        try:
            with open(path_to_zattrs, "r") as zattrs_file:
                zattrs = json.load(zattrs_file)
                offset = zattrs["offset"]
                res = zattrs["resolution"]
                offset_vx = (
                    int(offset[0] / res[0]),
                    int(offset[1] / res[1]),
                    int(offset[2] / res[2]),
                )
                logger.info(f"offset in voxels is {offset_vx}")
                return offset_vx
        except Exception as e:
            logger.error(
                "Could not find the zattrs at %s. Error in Cell3DDataset class %s"
                % (path_to_zattrs, e)
            )
            raise e

    @staticmethod
    def _verify_zarr(path_to_zarr: Path) -> bool:
        """_summary_

        Parameters
        ----------
        path_to_zarr : Path
            verify that the path to the zarr is a directory and has a volumes folder
        Returns
        -------
        bool
            True if the zarr is verified, False otherwise
        Raises
        ------
        e
        """
        try:
            if not path_to_zarr.is_dir():
                logger.error("class is not a zarr!")
            if not path_to_zarr.joinpath("volumes").is_dir():
                logger.error("No volumes folder in the zarr!")
            return True
        except Exception as e:
            logger.error(
                "Could not find the zarr at %s. Error in Cell3DDataset class %s"
                % (path_to_zarr, e)
            )
            raise e

    def _get_labels_offset(self) -> Dict[str, Tuple[int, int, int]]:
        """Collect the voxel offset of every labelled organelle volume.

        Returns
        -------
        Dict[str, Tuple[int, int, int]]
            dictionary of the label offsets, keyed by organelle name

        Raises
        ------
        FileNotFoundError
            if the labels folder is not found
        """
        try:
            label_offsets = {}
            path_to_labels = self.path_to_zarr.joinpath("volumes/labels")

            if not path_to_labels.exists():
                raise FileNotFoundError(
                    "Could not find the labels at %s" % path_to_labels
                )
            for organelle_type in path_to_labels.iterdir():
                logger.info(f"organelle_type: {organelle_type}")
                if organelle_type.is_dir() and verify_zattrs(
                    str(organelle_type.joinpath(".zattrs"))
                ):
                    logger.info(f"path to organelle is {organelle_type}")
                    logger.info(
                        f".zattrs is {organelle_type.joinpath('.zattrs').exists()}"
                    )
                    label_offsets[f"organelle_{organelle_type.name}"] = (
                        self._get_offset(organelle_type.joinpath(".zattrs"))
                    )

            return label_offsets
        except Exception as e:
            logger.error("Could not find the labels at %s" % e)
            raise e

    def _get_raw_offset(self) -> Dict[str, Tuple[int, int, int]]:
        """Collect the voxel offset of the raw volume.

        Returns
        -------
        Dict[str, Tuple[int, int, int]]
            dictionary with the raw offset under the key "raw"

        Raises
        ------
        FileNotFoundError
            if the raw volume or its .zattrs is not found
        """
        try:
            raw_offsets = {}
            path_to_raw = self.path_to_zarr.joinpath("volumes/raw_equalized_0.02")
            if not path_to_raw.exists():
                raise FileNotFoundError("Could not find the raw at %s" % path_to_raw)

            if verify_zattrs(str(path_to_raw.joinpath(".zattrs"))):
                raw_offsets["raw"] = self._get_offset(path_to_raw.joinpath(".zattrs"))
                return raw_offsets
            else:
                raise FileNotFoundError(
                    "Could not find the zattrs for the raw at %s"
                    % path_to_raw.joinpath(".zattrs")
                )
        except Exception as e:
            logger.error("Could not find the raw at %s" % e)
            raise e

    def fix_offset_position(
        self,
        dims: Tuple[str, str, str] = ("z", "y", "x"),
    ) -> None: 
        if self.zarr_store:
            raw_equalized = self.zarr_store["volumes/raw_equalized_0.02"][:]
            self.dataset["volumes/raw_equalized_0.02"] = xr.DataArray(
                raw_equalized, dims=dims
            )
            raw_equalized_offset = self.raw_offsets["raw"]
            for organelle_type, offset in self.label_offsets.items():
                # allocate a fresh canvas per organelle so labels written for one
                # variable do not leak into the next
                labels = np.zeros_like(raw_equalized)
                organelle_name = organelle_type.split("_")[-1]
                label_values = self.zarr_store[f"volumes/labels/{organelle_name}"][:]

                labels_offset = offset
                z_start, z_end = (
                    labels_offset[0] - raw_equalized_offset[0],
                    labels_offset[0] + label_values.shape[0] - raw_equalized_offset[0],  # type: ignore
                )
                y_start, y_end = (
                    labels_offset[1] - raw_equalized_offset[1],
                    labels_offset[1] + label_values.shape[1] - raw_equalized_offset[1],  # type: ignore
                )
                x_start, x_end = (
                    labels_offset[2] - raw_equalized_offset[2],
                    labels_offset[2] + label_values.shape[2] - raw_equalized_offset[2],  # type: ignore
                )

                labels[z_start:z_end, y_start:y_end, x_start:x_end] = label_values[:]
                self.dataset[f"volumes/labels/{organelle_type}"] = xr.DataArray(
                    labels, dims=dims
                )
```
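
Here is the kind of lazy transform I have in mind (an untested sketch; the path, the organelle name, and the threshold are made up):

```python
import xarray as xr

# With chunks={} every variable opens as a dask array: nothing is read yet.
raw = xr.open_zarr(
    "path/to/cell.zarr", group="volumes", chunks={}
)["raw_equalized_0.02"]

# Normalization only builds a task graph; no 30+ GB allocation happens here.
normalized = (raw - raw.mean()) / raw.std()

# A custom elementwise function can stay lazy too:
# normalized = xr.apply_ufunc(my_normalize, raw, dask="parallelized",
#                             output_dtypes=[raw.dtype])

# Binarize one organelle's labels against a threshold.
mito = xr.open_zarr(
    "path/to/cell.zarr", group="volumes/labels", chunks={}
)["mito"]
binary = xr.where(mito > 0.5, 1, 0)

# Computation happens only when a concrete window is requested:
patch = normalized.isel(z=slice(0, 64), y=slice(0, 64), x=slice(0, 64)).compute()
```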
ThomasMGeo commented 1 month ago

Is the size an issue because it's out of memory, or do you just want something efficient?

For your normalization scheme, are you using scikit-learn preprocessors or rolling your own?

ArkashJ commented 1 month ago

Correct me if I'm wrong, but `__getitem__` is lazily evaluated by the DataLoader. I worry that applying my transformations to the xarray eagerly would be inefficient compared to that lazy evaluation. What would be the best way to apply my transformation pipeline? As for the normalization, I have my own function.
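
For instance, is a chunked pipeline like this toy one deferred the same way (just a sanity check, not my real data)?

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(8, 8), dims=("y", "x")).chunk({"y": 4})

result = (da - da.mean()) / da.std()  # builds a dask graph; nothing computed yet
print(type(result.data))              # a dask array, still lazy
print(result.compute())               # evaluation happens only here
```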


max-sixty commented 1 month ago

Can you give a minimal example, just a few lines, comparing what you do in pytorch with what's required in xarray?

ArkashJ commented 1 month ago

Sure! Here's a dataset class I made where I set up the initial paths and did some compute; in the `__getitem__` method, I apply a transform:

```python
from PIL import Image
import torch
from torch.utils.data import Dataset
from typing import Callable, Dict, Any, Optional
from pathlib import Path
import numpy as np
from cell_interactome.data.utils import (
    get_data_paths,
)
from cell_interactome.data.dinov2.transforms import make_normalize_transform
from torchvision.transforms import v2


class CellDataset2D(Dataset):
    """Implementation of a 2D dataset for cell images."""

    def __init__(
        self,
        base_data_dir: Path,
        path_to_txt: Path,
        search_suffix: str,
        max_files: int = -1,
        transform: Optional[Callable] = None,
    ) -> None:
        """
        Parameters
        ----------
        base_data_dir: Path
            The base directory where the data is stored. Has the following
            structure if ZFlattenMode is "MAX" or "MEAN" (otherwise there will
            be folders inside the frame folder for each z-stack labeled by
            their z-index):

                Experiment_0
                ├── 488nm_CamB
                │   └── frame_0
                │       └── part_0.pth
                ├── 560_CamA
                │   ├── frame_0
                │   └── frame_1
                └── 642_CamB
                    ├── frame_0
                    └── frame_1

            A .pth file should have the keys 'data' and 'metadata'.

            .pth file structure:
            data: imageData
                values: np.ndarray
                position: Tuple[int, int, int]

        path_to_txt: Path
            The path to the txt file containing the paths to the .pth files.

        search_suffix: str
            The suffix to search for in the data directory. Should include the
            final file extension (e.g. *.pth). Should NOT include leading or
            trailing /.

        max_files: int = -1
            The maximum number of files to load. If -1, all files are loaded.

        transform: Optional[Callable]
            A callable that takes in an image and returns a transformed image.
        """
        self.base_data_dir = base_data_dir
        self.path_to_txt = path_to_txt
        self.data_paths = get_data_paths(
            base_data_dir=base_data_dir,
            path_to_txt=path_to_txt,
            search_suffix=search_suffix,
            max_files=max_files,
        )
        self.transform = transform
        self.max_files = max_files

    def __len__(self) -> int:
        return len(self.data_paths)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        pth_path = self.data_paths[index]
        image2d = torch.load(pth_path, weights_only=False)
        data = image2d["data"]
        metadata = image2d["metadata"]
        path = metadata["path"]
        frame = metadata["frame"]
        wavelength = metadata["wavelength"]
        data["position"] = torch.from_numpy(data["position"])

        image = np.array(data["values"], dtype=np.uint8)
        image = Image.fromarray(image)
        # TODO: separate train and test transform logic
        if self.transform:
            output = self.transform(image)
        else:
            output = image

        metadata = {
            "pth_path": str(pth_path),
            "tif_path": str(path),
            "frame": frame,
            "wavelength": wavelength,
        }

        if not isinstance(output, dict):
            output = {"image": output}

        item = {
            **output,
            "metadata": metadata,
        }
        return item
```

An example of a transform would be:
```python

from typing import Union

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2


class ToTensor(v2.Transform):
    """
    Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    """

    def __init__(self) -> None:
        super().__init__()
        self.transform = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    def __call__(self, pic: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        return self.transform(pic)
```

Now I'm confused about how to do this efficiently with an xarray dataset, because with a PyTorch dataset I would feed it into a DataLoader and pass a collate function. I was planning on using xbatcher, but I'd love some direction on how to tackle this problem.

TomNicholas commented 1 month ago

We would love to have better integration with PyTorch, but as someone who has personally never used PyTorch I'm afraid this is still quite a lot in one go for me to easily follow 😅

(also I added syntax highlighting to your code examples for you)

> I was planning on using xbatcher

Summoning @maxrjones 🪄

ArkashJ commented 1 month ago

Thanks for the syntax highlight.

[For context, we're a computational biology lab at Harvard working with cell organelles. I was originally going to use PyTorch for the entire pipeline, but a lot of our tools use zarr files, so I decided to switch to xarray. I'd be happy to talk about xarray improvements that could help build such pipelines. In my opinion xarray is a very interesting tool for zarr processing, given that a lot of natural-science data is stored in zarr.]

I found a repo from a startup whose pipeline is Zarr -> Xarray -> Xbatcher: https://github.com/earth-mover/dataloader-demo/blob/main/main.py

I wonder if this is the solution to my problem.
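
Something like this is roughly what I'm imagining, based on the xbatcher docs (an untested sketch; the path, variable name, dimension names, and window sizes are all made up):

```python
import torch
import xarray as xr
import xbatcher
from torch.utils.data import DataLoader, Dataset

ds = xr.open_zarr("path/to/cell.zarr", chunks={})  # lazy, dask-backed

# Yield fixed-size 3D windows; input_dims sets the window shape.
bgen = xbatcher.BatchGenerator(ds, input_dims={"z": 64, "y": 64, "x": 64})


class XbatcherDataset(Dataset):
    """Thin adapter: one xbatcher window -> one torch sample."""

    def __init__(self, bgen, transform=None):
        self.bgen = bgen
        self.transform = transform

    def __len__(self):
        return len(self.bgen)

    def __getitem__(self, idx):
        window = self.bgen[idx]  # an xr.Dataset; only this window is loaded
        x = torch.as_tensor(window["raw"].values)  # "raw" is a placeholder name
        return self.transform(x) if self.transform else x


loader = DataLoader(XbatcherDataset(bgen), batch_size=4)
```

(xbatcher also ships torch adapters in xbatcher.loaders.torch, which might save writing the wrapper by hand.)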

max-sixty commented 1 month ago

You can inherit from a Dataset and implement your own __getitem__, which can do transformations, just like pytorch. I wouldn't recommend it unless you're an advanced user.

I think the best path forward is to attempt to build the data processing using idiomatic xarray, and then use that experience to reflect on the differences between xarray & pytorch. I don't think it's that productive to try to force xarray into idiomatic pytorch (though I'm a fan of pytorch and have used it a lot!).

If you would like more direction, perhaps open a discussion with small examples of what you've tried so far and what you've found difficult.

I'll close this for now, thanks.