Machine-Learning-Dynamical-Systems / kooplearn

A Python package to learn the Koopman operator.
https://kooplearn.readthedocs.io
MIT License

Unnecessary memory usage of the `TrajectoryContextDataset` #13

Open Danfoa opened 3 months ago

Danfoa commented 3 months ago

Hi @pietronvll,

Memory management within the TrajectoryContextDataset class appears to be very inefficient.

When I do:

train_trajs = np.float32(train_trajs)
val_trajs = np.float32(val_trajs)
test_trajs = np.float32(test_trajs)

log.info(f"Training dataset memory footprint {train_trajs.nbytes / 1e6:.2f} [MB]")
# Split data into context windows using kooplearn. Load data to device if available, and store a reference to the
# raw dataset for the required regression tasks.
self.train_dataset = multi_traj_to_context(train_trajs,
                                           context_window_len=self.pred_horizon + self.lookback_len,
                                           backend="torch",
                                           device=self.device)

log.info(f"Torch dataset memory footprint {self.train_dataset.data.element_size() * self.train_dataset.data.nelement() / 1e6:.2f} [MB]")

We go from 93 MB to 4.4 GB of memory consumption. This becomes quite problematic given the default behaviour of this class, which loads the data tensor onto the GPU for fast training.

[2024-06-29 13:15:48,471][data.DynamicsDataModule][INFO] - Training dataset memory footprint 93.00 [MB]
[2024-06-29 13:15:48,564][data.DynamicsDataModule][INFO] - Torch dataset memory footprint 4404.44 [MB]
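
One plausible reading of these numbers, assuming the dataset materializes every context window as its own copy with a stride of one: each time step is then stored roughly once per window it appears in, so the footprint grows by about the context window length (here 4404.44 / 93.00 ≈ 47). The sketch below reproduces that effect on a toy array using NumPy only. The array shape and the window length are made up for illustration and are not taken from the run above.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Toy trajectory: 1000 time steps, 8 features (illustrative sizes only).
traj = np.zeros((1000, 8), dtype=np.float32)
context_len = 47  # hypothetical window length, chosen for illustration

# Zero-copy view of all windows; shape is (n_windows, features, context_len).
windows_view = sliding_window_view(traj, context_len, axis=0)

# Materializing the windows, as a naive stacking approach would, copies the data.
windows_copy = np.ascontiguousarray(windows_view)

ratio = windows_copy.nbytes / traj.nbytes
print(f"raw: {traj.nbytes / 1e6:.3f} MB, "
      f"materialized: {windows_copy.nbytes / 1e6:.3f} MB, "
      f"ratio: {ratio:.1f}")
```

The view shares memory with `traj`, while the materialized copy does not, which is the difference between keeping the original footprint and multiplying it.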

I will update soon with details.

Danfoa commented 3 months ago

Since the class hierarchy of the ContextWindow datasets is highly complex and difficult to edit, modify, or justify, I propose an alternative setup which:

  1. Does not use any class hierarchy: the only class we inherit from is the abstract class torch.utils.data.Dataset.
  2. Handles both numpy and torch backends.
  3. Maintains the memory footprint of the original trajectory arrays.
  4. Reduces the number of lines of code.

I prefer this because the current data-managing pipeline is long, and its hierarchical structure makes it hard to understand and modify. As also discussed with @prolearner, we believe it is best to keep the OOP to a bare minimum and provide a small number of concise, well-documented, and simple classes/functions.
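
The core idea behind points 1 to 3, stripped of backend handling and validation, can be sketched as follows: keep the raw trajectories as they are and store only one `(trajectory index, slice)` pair per context window. This is a minimal NumPy-only illustration, not the proposed class itself; `trajs`, `indices`, and `get_window` are hypothetical names used only here.

```python
import numpy as np

# Two toy trajectories of different lengths, shape (time, features).
trajs = [
    np.arange(20, dtype=np.float32).reshape(10, 2),
    np.arange(12, dtype=np.float32).reshape(6, 2),
]
context_length, time_lag = 4, 2

# One (trajectory index, slice) pair per context window; no data is copied here.
indices = [
    (i, slice(start, start + context_length))
    for i, traj in enumerate(trajs)
    for start in range(0, traj.shape[0] - context_length + 1, time_lag)
]


def get_window(idx):
    """Return the idx-th context window as a view into the raw trajectory."""
    traj_idx, window = indices[idx]
    return trajs[traj_idx][window]


sample = get_window(0)
print(len(indices), sample.shape, np.shares_memory(sample, trajs[0]))
```

Because basic slicing returns a view in NumPy (and likewise in torch), each sample points into the original array, so the extra memory cost is only the list of index tuples.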

class TrajectoryContextDataset(torch.utils.data.Dataset):
    """Class for a collection of context windows with tensor features."""

    def __init__(
            self,
            trajectories: list[ArrayLike, ...],
            context_length: int = 2,
            time_lag: int = 1,
            backend: str = "numpy",
            shuffle: bool = False,
            seed: int = 1234,
            **backend_kw,
            ):
        """ Initialize a Dataset instance that can be passed to a torch.utils.data.DataLoader.

        Args:
            trajectories: (list[ArrayLike, ...]) A list of trajectories (of potentially different time lengths), each of
                shape (time, *features).
            context_length: (int) Number of time-frames per context window. Default to 2.
            time_lag: (int) Time lag between successive context windows. Default to 1.
            backend: (str) Specifies the backend to be used (``'numpy'``, ``'torch'``). Default to ``'numpy'``.
            shuffle: (bool) If True, shuffles the context windows. Default to False.
            seed: (int) Seed for the random number generator. Default to 1234.
            **backend_kw: (dict) Keyword arguments to pass to the backend.
                If backend='torch', for instance it is possible to specify the device and type of the data samples.
                If backend='numpy', it is possible to specify the dtype of the data samples
        """
        # Check the type first so that non-integer inputs are rejected before the comparison.
        if not isinstance(context_length, int) or context_length < 1:
            raise ValueError(f"context_length must be an integer >= 1, got {context_length}")

        if time_lag < 1:
            raise ValueError(f"time_lag must be >= 1, got {time_lag}")

        if not isinstance(trajectories, list):
            raise ValueError(f"Expected list of trajectories of shape (time, *features), got {type(trajectories)}.")

        torch, backend = parse_backend(backend)

        self._backend = backend
        self._context_length = context_length
        self._time_lag = time_lag
        self._indices = []
        self._raw_data = []  # Variable containing the trajectories in the desired backend.

        # Convert trajectories to the desired backend. We copy data only once, and keep the original memory footprint.
        if backend == "numpy":  # If backend is numpy, we convert the data to numpy.
            self._raw_data = [np.array(traj, **backend_kw) for traj in trajectories]
        elif backend == "torch":  # Load raw data ONCE to GPU if specified in backend_kw.
            self._raw_data = [torch.tensor(traj, **backend_kw) for traj in trajectories]

        # Compute the list of indices (traj_idx, slice(start, end)) for each ContextWindow.
        for traj_idx, traj_data in enumerate(self._raw_data):
            if traj_data.ndim < 2:
                raise ShapeError(
                    f"Shape of trajectory {traj_idx} is {traj_data.shape}. "
                    f"Expected an array of shape (time, *features) with at least 2 dimensions."
                    )
            context_window_slices = _slices_from_traj_len(time_horizon=traj_data.shape[0],
                                                          context_length=context_length,
                                                          time_lag=time_lag)
            # Store a tuple of (traj_idx, context window slice) for each context window.
            self._indices.extend([(traj_idx, s) for s in context_window_slices])

        self._memory_footprint = None
        self._shuffled = False

        # Shuffle only when requested; an unconditional call here would ignore the `shuffle` flag.
        if shuffle:
            self.shuffle(seed=seed)

        log.info(f"TrajectoryContextDataset initialized with {len(self)} context windows.")

    def shuffle(self, seed: int | None = None):
        """Shuffles the context windows in place. If a seed is given, it seeds NumPy's global RNG."""
        if seed is not None:
            np.random.seed(seed)
        np.random.shuffle(self._indices)
        self._shuffled = True

    @property
    def backend(self):
        return str(self._backend)

    @property
    def context_length(self):
        return int(self._context_length)

    @property
    def time_lag(self):
        return int(self._time_lag)

    @property
    def is_shuffled(self):
        return self._shuffled

    @property
    def memory_footprint(self):
        """Returns the memory footprint of the dataset in bytes."""
        if self._memory_footprint is None:
            if self._backend == "numpy":
                self._memory_footprint = sum(traj.nbytes for traj in self._raw_data)
            elif self._backend == "torch":
                self._memory_footprint = sum(traj.element_size() * traj.nelement() for traj in self._raw_data)
        return self._memory_footprint

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        traj_idx, slice_idx = self._indices[idx]
        sample = self._raw_data[traj_idx][slice_idx]
        return sample

    def __repr__(self):
        device = "cpu"
        if self._backend == "torch":
            if len(self._raw_data) > 0:
                device = self._raw_data[0].device
        return f"Memory use: {self.memory_footprint / 1e6:.2f} MB on {device}"

def _slices_from_traj_len(time_horizon: int, context_length: int, time_lag: int) -> list[slice]:
    """ Returns the list of slices (start_time_idx, end_time_idx) for each context window in the trajectory.
    Args:
        time_horizon: (int) Number of time-frames in the trajectory.
        context_length: (int) Number of time-frames per context window
        time_lag: (int) Time lag between successive context windows.
    Returns:
        list[slice]: List of slices for each context window.

    Examples
    --------
    >>> time_horizon, context_length, time_lag = 10, 4, 2
    >>> slices = _slices_from_traj_len(time_horizon, context_length, time_lag)
    >>> for s in slices:
    ...     print(f"start: {s.start}, end: {s.stop}")
    start: 0, end: 4
    start: 2, end: 6
    start: 4, end: 8
    start: 6, end: 10
    """
    slices = []
    for start in range(0, time_horizon - context_length + 1, time_lag):
        end = start + context_length
        slices.append(slice(start, end))

    return slices

def traj_to_contexts(
        trajectory: np.ndarray,
        context_window_len: int = 2,
        time_lag: int = 1,
        backend: str = "auto",
        **backend_kwargs,
        ):
    """Transforms a single trajectory to a sequence of context windows.

    Args:
    ----
        trajectory (np.ndarray): A trajectory of shape ``(n_frames, *features_shape)``.
        context_window_len (int, optional): Length of the context window. Default to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Default to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). If set to ``'auto'``,
        will use the same backend of the trajectory. Default to ``'auto'``.
        backend_kwargs (dict, optional): Keyword arguments to pass to the backend. For example, if ``'torch'``,
        it is possible to specify the device of the tensor.

    Returns:
    -------
        TrajectoryContextDataset: A sequence of context windows.
    """
    return TrajectoryContextDataset(
        trajectories=[trajectory],
        context_length=context_window_len,
        time_lag=time_lag,
        backend=backend,
        **backend_kwargs,
        )

def multi_traj_to_context(
        trajectories: list[ArrayLike, ...],
        context_window_len: int = 2,
        time_lag: int = 1,
        backend: str = "auto",
        **backend_kwargs,
        ):
    """Transforms a collection of trajectories to a sequence of context windows.

    Args:
    ----
        trajectories (list[ArrayLike, ...]): A list of trajectories, each of shape ``(n_frames, *features_shape)``.
        context_window_len (int, optional): Length of the context window. Default to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Default to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). If set to ``'auto'``,
        will use the same backend of the trajectory. Default to ``'auto'``.
        backend_kwargs (dict, optional): Keyword arguments to pass to the backend. For example, if ``'torch'``,
        it is possible to specify the device of the tensor.

    Returns:
    -------
        TrajectoryContextDataset: A sequence of context windows.
    """
    return TrajectoryContextDataset(
        trajectories=trajectories,
        context_length=context_window_len,
        time_lag=time_lag,
        backend=backend,
        **backend_kwargs,
        )
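
As a quick sanity check on the windowing logic in `_slices_from_traj_len`, the number of context windows per trajectory has a closed form: `(time_horizon - context_length) // time_lag + 1` when the trajectory is at least one window long, and zero otherwise. The helper below is hypothetical (it is not part of kooplearn) and only compares that formula against the same `range` expression used above.

```python
def n_context_windows(time_horizon: int, context_length: int, time_lag: int) -> int:
    """Closed-form count of context windows (hypothetical helper, not part of kooplearn)."""
    if time_horizon < context_length:
        return 0
    return (time_horizon - context_length) // time_lag + 1


# Compare against the range expression used by _slices_from_traj_len.
for T, L, lag in [(10, 4, 2), (10, 4, 1), (5, 5, 3), (3, 4, 1)]:
    expected = len(range(0, T - L + 1, lag))
    assert n_context_windows(T, L, lag) == expected

print(n_context_windows(10, 4, 2))  # 4, matching the docstring example
```

Summing this count over all trajectories gives the expected `len(dataset)`, which is a cheap way to verify a dataset before training.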