Closed ArkashJ closed 1 month ago
What is your issue?
In Pytorch datasets, we can do transform operations in the `__getitem__` function given an index. Is there a way to do so in xarray as well, or is it recommended that we convert the xarray to batches using xbatcher, convert that into a pytorch data loader, and work on the data loader?
Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!
Could you give an example?
Sure, so I'm working on a Dataset class that will be fed into a 3D UNet, and my data is stored in a Zarr file. My class does some computation using the offsets and stores the data in an xarray. I need to perform a normalization and a binarization of labels based on a threshold. If this were a pytorch dataset, I'd have a `__getitem__` function with an idx to which I could apply a preprocessing transform. Here, however, I'm not sure what the best way to apply the transform would be, because my data can run to gigabytes (30+ GB). Not sure if using apply_ufunc would be useful here? (A lazy apply_ufunc sketch follows the code below.)
```python
from typing import Tuple, Dict
from loguru import logger
from pathlib import Path
import zarr
import numpy as np
import json
import xarray as xr
from incasem_v2.utils.dtypes import verify_zattrs
class Cell3DDataset:
def __init__(
self,
path_to_zarr: Path,
):
self.path_to_zarr = path_to_zarr
self.zarr_store = (
zarr.open(str(self.path_to_zarr), mode="r")
if self._verify_zarr(self.path_to_zarr)
else None
)
self.label_offsets: Dict[str, Tuple[int, int, int]] = dict(
sorted(self._get_labels_offset().items())
)
self.raw_offsets: Dict[str, Tuple[int, int, int]] = dict(
sorted(self._get_raw_offset().items())
)
self.metric_mask_offsets: Dict[str, Tuple[int, int, int]] = dict(
sorted(self._get_metric_mask_offset().items())
)
self.dataset = self._create_dataset()
def _create_dataset(self) -> xr.Dataset:
ds = xr.Dataset()
return ds
@staticmethod
def _get_offset(
path_to_zattrs: Path,
) -> Tuple[int, int, int]:
try:
with open(path_to_zattrs, "r") as zattrs_file:
zattrs = json.load(zattrs_file)
offset = zattrs["offset"]
res = zattrs["resolution"]
offset_vx = (
int(offset[0] / res[0]),
int(offset[1] / res[1]),
int(offset[2] / res[2]),
)
logger.info(f"offset in voxels is {offset_vx}")
return offset_vx
except Exception as e:
logger.error(
"Could not find the zattrs at %s. Error in Cell3DDataset class %s"
% (path_to_zattrs, e)
)
raise e
@staticmethod
def _verify_zarr(path_to_zarr: Path) -> bool:
"""_summary_
Parameters
----------
path_to_zarr : Path
verify that the path to the zarr is a directory and has a volumes folder
Returns
-------
bool
True if the zarr is verified, False otherwise
Raises
------
e
"""
        try:
            if not path_to_zarr.is_dir():
                logger.error("class is not a zarr!")
                return False
            if not path_to_zarr.joinpath("volumes").is_dir():
                logger.error("No volumes folder in the zarr!")
                return False
            return True
except Exception as e:
logger.error(
"Could not find the zarr at %s. Error in Cell3DDataset class %s"
% (path_to_zarr, e)
)
raise e
def _get_labels_offset(self) -> Dict[str, Tuple[int, int, int]]:
"""_summary_
Returns
-------
Dict[str, Tuple[int, int, int]]
dictionary of the label offsets
Raises
------
FileNotFoundError
if the labels are not found, error out
e
_description_
"""
try:
label_offsets = {}
path_to_labels = self.path_to_zarr.joinpath("volumes/labels")
if not path_to_labels.exists():
raise FileNotFoundError(
"Could not find the labels at %s" % path_to_labels
)
for organelle_type in path_to_labels.iterdir():
logger.info(f"organelle_type: {organelle_type}")
if organelle_type.is_dir() and verify_zattrs(
str(organelle_type.joinpath(".zattrs"))
):
logger.info(f"path to organelle is {organelle_type}")
logger.info(
f".zattrs is {organelle_type.joinpath('.zattrs').exists()}"
)
label_offsets[f"organelle_{organelle_type.name}"] = (
self._get_offset(organelle_type.joinpath(".zattrs"))
)
return label_offsets
except Exception as e:
logger.error("Could not find the labels at %s" % e)
raise e
def _get_raw_offset(self) -> Dict[str, Tuple[int, int, int]]:
"""_summary_
Returns
-------
Dict[str, Tuple[int, int, int]]
dictionary of the raw offsets
Raises
------
FileNotFoundError
if the raw is not found
e
"""
try:
raw_offsets = {}
path_to_raw = self.path_to_zarr.joinpath("volumes/raw_equalized_0.02")
if not path_to_raw.exists():
raise FileNotFoundError("Could not find the raw at %s" % path_to_raw)
if verify_zattrs(str(path_to_raw.joinpath(".zattrs"))):
raw_offsets["raw"] = self._get_offset(path_to_raw.joinpath(".zattrs"))
return raw_offsets
else:
raise FileNotFoundError(
"Could not find the zattrs for the raw at %s"
% path_to_raw.joinpath(".zattrs")
)
except Exception as e:
logger.error("Could not find the raw at %s" % e)
raise e
def fix_offset_position(
self,
dims: Tuple[str, str, str] = ("z", "y", "x"),
) -> None:
if self.zarr_store:
raw_equalized = self.zarr_store["volumes/raw_equalized_0.02"][:]
self.dataset["volumes/raw_equalized_0.02"] = xr.DataArray(
raw_equalized, dims=dims
)
raw_equalized_offset = self.raw_offsets["raw"]
            for organelle_type, offset in self.label_offsets.items():
                organelle_name = organelle_type.split("_")[-1]
                label_values = self.zarr_store[f"volumes/labels/{organelle_name}"][:]
                # reset the canvas for each organelle so labels from one
                # organelle do not bleed into the arrays of the others
                labels = np.zeros_like(raw_equalized)
labels_offset = offset
z_start, z_end = (
labels_offset[0] - raw_equalized_offset[0],
labels_offset[0] + label_values.shape[0] - raw_equalized_offset[0], # type: ignore
)
y_start, y_end = (
labels_offset[1] - raw_equalized_offset[1],
labels_offset[1] + label_values.shape[1] - raw_equalized_offset[1], # type: ignore
)
x_start, x_end = (
labels_offset[2] - raw_equalized_offset[2],
labels_offset[2] + label_values.shape[2] - raw_equalized_offset[2], # type: ignore
)
labels[z_start:z_end, y_start:y_end, x_start:x_end] = label_values[:]
self.dataset[f"volumes/labels/{organelle_type}"] = xr.DataArray(
labels, dims=dims
                )
```
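On the apply_ufunc question above: a minimal sketch (an editorial illustration, not code from this thread) of how the normalization could stay lazy, assuming the zarr arrays are wrapped in dask; the path, the mean/std values, and the threshold are all placeholders:

```python
import dask.array as da
import xarray as xr
import zarr

store = zarr.open("path/to/cell.zarr", mode="r")
# wrap the zarr array in dask so it stays chunked and lazy
raw = xr.DataArray(
    da.from_zarr(store["volumes/raw_equalized_0.02"]),
    dims=("z", "y", "x"),
)

def normalize(block, mean, std):
    # receives plain numpy blocks, one chunk at a time
    return (block - mean) / std

normalized = xr.apply_ufunc(
    normalize,
    raw,
    kwargs={"mean": 0.5, "std": 0.2},  # placeholder statistics
    dask="parallelized",
    output_dtypes=[float],
)
# binarization needs no apply_ufunc at all: comparisons are lazy too
binary = (normalized > 0.5).astype("uint8")  # placeholder threshold
```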
Is the size an issue because it's out of memory, or do you just want something efficient?
For your normalization scheme, are you using scikit-learn preprocessors or rolling your own?
Correct me if I'm wrong, but the `__getitem__` function is lazily evaluated by the dataloader. I worry that applying my transformations to the xarray eagerly would be inefficient compared to that lazy evaluation. What would be the best way to apply my transformation pipeline? As for the normalization, I have my own function.
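One editorial note here, under the same assumptions as the sketch above: a dask-backed xarray is itself lazily evaluated, so the eager-vs-lazy trade-off is not forced on you. Each operation only extends a task graph until a concrete result is requested:

```python
# nothing is read or computed here; these lines only build the dask graph
normalized = (raw - raw.mean()) / raw.std()
binary = (normalized > 0.5).astype("uint8")

# work happens only when a concrete value is requested,
# e.g. one training patch at a time
patch = binary[0:64, 0:64, 0:64].compute()
```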
Can you give a minimal example, just a few lines, comparing what you do in pytorch with what's required in xarray?
Sure! Here's a dataset class I made where I set up the initial paths and did some compute, and then in the `__getitem__` function I'm doing a transform:
```python
from PIL import Image
import torch
from torch.utils.data import Dataset
from typing import Callable, Dict, Any, Optional
from pathlib import Path
import numpy as np
from cell_interactome.data.utils import (
get_data_paths,
)
from cell_interactome.data.dinov2.transforms import make_normalize_transform
from torchvision.transforms import v2
class CellDataset2D(Dataset):
"""Implementation of a 2D dataset for cell images."""
def __init__(
self,
base_data_dir: Path,
path_to_txt: Path,
search_suffix: str,
max_files: int = -1,
transform: Optional[Callable] = None,
) -> None:
"""
Parameters
----------
base_data_dir: Path
The base directory where the data is stored. Has the following structure if ZFlattenMode is "MAX" or "MEAN" (otherwise there will be folders inside the frame folder for each z-stack labeled by their z-index):
Experiment_0
├── 488nm_CamB
│ └── frame_0
│ └── part_0.pth
├── 560_CamA
│ ├── frame_0
│ └── frame_1
└── 642_CamB
├── frame_0
└── frame_1
            A .pth file should have the keys 'data' and 'metadata'.
.pth file structure:
data: imageData
values: np.ndarray
position: Tuple[int, int, int]
path_to_txt: Path
The path to the txt file containing the paths to the .pth files.
search_suffix: str
The suffix to search for in the data directory. Should include the final file extension (e.g. *.pth). Also, should NOT include leading and trailing /.
max_files: int = -1
The maximum number of files to load. If -1, all files are loaded.
transform: Optional[Callable]
A callable that takes in an image and returns a transformed image.
"""
self.base_data_dir = base_data_dir
self.path_to_txt = path_to_txt
self.data_paths = get_data_paths(
base_data_dir=base_data_dir,
path_to_txt=path_to_txt,
search_suffix=search_suffix,
max_files=max_files,
)
self.transform = transform
self.max_files = max_files
def __len__(self) -> int:
return len(self.data_paths)
def __getitem__(self, index: int) -> Dict[str, Any]:
pth_path = self.data_paths[index]
image2d = torch.load(pth_path, weights_only=False)
data = image2d["data"]
metadata = image2d["metadata"]
path = metadata["path"]
frame = metadata["frame"]
wavelength = metadata["wavelength"]
data["position"] = torch.from_numpy(data["position"])
image = np.array(data["values"], dtype=np.uint8)
image = Image.fromarray(image)
# TODO: separate train and test transform logic
if self.transform:
output = self.transform(image)
else:
output = image
metadata = {
"pth_path": str(pth_path),
"tif_path": str(path),
"frame": frame,
"wavelength": wavelength,
}
if not isinstance(output, dict):
output = {"image": output}
item = {
**output,
"metadata": metadata,
}
        return item
```
An example of a transform would be:
```python
from typing import Union

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2

class ToTensor(v2.Transform):
"""
Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
"""
def __init__(self) -> None:
super().__init__()
self.transform = v2.Compose(
[v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)
def __call__(self, pic: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        return self.transform(pic)
```
Now I'm confused as to how I can do this efficiently with an xarray dataset, because with a pytorch dataset I would feed it into a dataloader and pass a collate function there. I was planning on using xbatcher, but I'd love direction on how to tackle this problem.
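A sketch of that route (an editorial illustration; the variable names, patch size, and transform hook are placeholders): xbatcher slices the lazy Dataset into patches, and a thin `torch.utils.data.Dataset` wrapper restores the familiar per-item transform hook. Recent xbatcher versions also ship ready-made wrappers in `xbatcher.loaders.torch`.

```python
import torch
import xbatcher

# `ds` is assumed to be an xr.Dataset with dask-backed "raw" and "labels"
# variables on dims ("z", "y", "x")
bgen = xbatcher.BatchGenerator(ds, input_dims={"z": 64, "y": 64, "x": 64})

class XarrayPatches(torch.utils.data.Dataset):
    """Thin wrapper: the transform runs per patch, as in __getitem__."""

    def __init__(self, batch_generator, transform=None):
        self.bgen = batch_generator
        self.transform = transform

    def __len__(self):
        return len(self.bgen)

    def __getitem__(self, idx):
        batch = self.bgen[idx]                    # one small xr.Dataset patch
        x = torch.as_tensor(batch["raw"].values)  # computes only this patch
        y = torch.as_tensor(batch["labels"].values)
        if self.transform:
            x = self.transform(x)
        return x, y

loader = torch.utils.data.DataLoader(XarrayPatches(bgen), batch_size=4)
```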
We would love to have better integration with PyTorch, but as someone who has personally never used PyTorch I'm afraid this is still quite a lot in one go for me to easily follow 😅
(also I added syntax highlighting to your code examples for you)
> I was planning on using xbatcher
Summoning @maxrjones 🪄
Thanks for the syntax highlight.
[For context, we're a computational biology lab at Harvard working with cell organelles. I was originally going to use pytorch for the entire pipeline, but a lot of our tools use zarr files, so I decided to switch to xarray. I'd be happy to talk about xarray improvements that would help build such pipelines. In my opinion xarray is a very interesting tool for zarr processing, given that a lot of natural-science data is stored in zarr.]
I found a repo from a startup whose pipeline goes Zarr -> Xarray -> Xbatcher: https://github.com/earth-mover/dataloader-demo/blob/main/main.py. I wonder if this is the solution to my problem.
You can inherit from a `Dataset` and implement your own `__getitem__`, which can do transformations, just like pytorch. I wouldn't recommend it unless you're an advanced user.
I think the best path forward is to attempt to build the data processing using idiomatic xarray, and then use that experience to reflect on the differences between xarray & pytorch. I don't think it's that productive to try to force xarray into idiomatic pytorch (though I'm a fan of pytorch and have used it a lot!).
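For concreteness, a minimal sketch of that idiomatic route (reusing the hypothetical `raw` from the earlier sketches and assuming a dask-backed `labels` DataArray built the same way): do the whole-volume preprocessing with lazy xarray ops, persist it once, and leave only per-patch work to the training loop.

```python
# assemble the lazily transformed variables into one Dataset
preprocessed = xr.Dataset(
    {
        "raw": (raw - raw.mean()) / raw.std(),
        "labels": (labels > 0).astype("uint8"),
    }
)
# to_zarr streams chunk by chunk; the 30+ GB volume never sits in memory
preprocessed.to_zarr("path/to/preprocessed.zarr", mode="w")
```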
If you would like more direction, perhaps open a discussion with small examples of what you've tried so far and what you've found difficult.
I'll close this for now, thanks.