NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

Error when implementing probabilistic augmentations #5544

Open ujjwalnur opened 2 days ago

ujjwalnur commented 2 days ago

Describe the question.

I have tested with DALI 1.38 and DALI 1.39 on CUDA 12.2.

I specify the configuration of each augmentation operation in a Hydra configuration file and then construct those augmentations inside a pipeline. For instance, the following is an example configuration file for random_horizontal_flip:

Screenshot 2024-06-30 221624

The corresponding implementation is as below :

Screenshot 2024-06-30 221733

In a similar fashion I have configurations for other augmentations as well.

In my main code, I construct a pipeline as follows (explanation below):

Screenshot 2024-06-30 221952

As you can see, the sequence of operations is:

  1. You pass a set of hydra configuration files to a class MyDataset .
  2. The corresponding functions are then built as partial functions using hydra.utils.instantiate
  3. Then they are used for augmentation inside the pipeline using self._transform_data(....)
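The three steps above can be sketched without Hydra or DALI. In this minimal sketch, `build_partial` is a hypothetical stand-in for `hydra.utils.instantiate(cfg, _partial_=True)` (which returns a `functools.partial`), the config dictionaries are invented for illustration and are not the contents of the screenshots, and the two transforms only record their names instead of calling DALI operators.

```python
from functools import partial, reduce


# Hypothetical stand-ins for the real augmentations: each takes the list of
# inputs plus keyword settings and returns a new list with its name appended.
def horizontal_flip(inputs, probability=0.5):
    return inputs + [f"horizontal_flip(p={probability})"]


def resize_exact(inputs, height, width):
    return inputs + [f"resize_exact({height}x{width})"]


REGISTRY = {"horizontal_flip": horizontal_flip, "resize_exact": resize_exact}


def build_partial(cfg):
    """Assumed stand-in for instantiate(cfg, _partial_=True): bind the
    configured keyword arguments and leave `inputs` to be supplied later."""
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return partial(REGISTRY[cfg["_target_"]], **kwargs)


# Illustrative configs only; the real YAML files are in the screenshots.
configs = [
    {"_target_": "resize_exact", "height": 256, "width": 256},
    {"_target_": "horizontal_flip", "probability": 0.5},
]

# Step 2: build one partial function per configuration entry.
transforms = [build_partial(cfg) for cfg in configs]


def apply_all(inputs, transforms):
    # Step 3: apply each partial in the listed order.
    return reduce(lambda acc, t: t(acc), transforms, inputs)


result = apply_all(["image"], transforms)
print(result)
```

Keeping the transforms as partials means every setting lives in the configuration, and the pipeline only has to supply the inputs.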

I have chosen this mechanism as it allows me to choose and modify all aspects of data augmentation using Hydra.

However, when I run it I often get error messages suggesting that something is wrong with dali_fn.random.uniform(...). Whenever I use two augmentations (e.g. resize_exact and horizontal_flip), I get strange errors that disappear when:

a) Either I don't use any random.uniform function

b) Or I use only one augmentation

What is going wrong in my implementation? Any hints? Shouldn't NVIDIA DALI treat random.uniform nodes defined in different functions as different nodes?

JanuszL commented 2 days ago

Hi @ujjwalnur,

Thank you for reaching out. I think it would be best to provide a code snippet we could run on our end that reproduces the problem you observe. Otherwise, it is difficult to find the source of the issue.

ujjwalnur commented 2 days ago

Hi @JanuszL,

Thanks for the response. Since it is TFRecord-based code, it would be futile to paste it here directly. Let me write a small snippet with a non-TFRecord dataset, and I will paste it here soon so you can play with it.

JanuszL commented 2 days ago

Yes, anything that reproduces the problem, like a toy example, should do.

ujjwalnur commented 2 days ago

Since I cannot attach *.py files here, I am providing a code snippet you can execute directly. It is a bit long because I wanted it to mimic the functionality I am experimenting with. I have tried to document it as well as I could.

Steps to run the code:

  1. Take the following Python script and put it in a folder.
  2. Take the three YAML files and put them in the folder structure shown in the screenshot below.
  3. Run the Python script. If you name the folders differently, you might have to set the config_path in the @hydra.main decorator in the code.

Screenshot 2024-07-01 114407

Here are the contents of each YAML file:

config.yaml

Screenshot 2024-07-01 114510

horizontal_flip.yaml

Screenshot 2024-07-01 114552

resize_exact.yaml

Screenshot 2024-07-01 114617

Main script to run

from functools import reduce
from typing import Optional, List, Dict
import tensorflow as tf
import cv2
import numpy as np
import nvidia.dali.fn as dali_fn
import nvidia.dali.types as dali_types
from hydra.utils import instantiate
from omegaconf import DictConfig
from typing import Callable, Sequence
from nvidia.dali.types import DALIDataType
from nvidia.dali.data_node import DataNode
from nvidia import dali
from omegaconf import OmegaConf
import nvidia.dali.plugin.tf as dali_tf
import hydra

TensorList = dali.tensors.TensorListCPU | dali.tensors.TensorListGPU

def compose(*functions):
    """
    Args:
        *functions: The functions that should be composed together.

    Returns:
        The composed function that applies the given functions in order.

    """

    def compose_two(f, g):
        # Apply f first, then g, so the functions run in the order given.
        return lambda x: g(f(x))

    return reduce(compose_two, functions, lambda x: x)

def horizontal_flip(
    inputs: List[TensorList],
    img_indices: List[int],
    box_indices: Optional[List[int]],
    mask_indices: Optional[List[int]],
    probability: float = 0.5,
    img_flip_kwargs: Optional[Dict] = None,
    box_flip_kwargs: Optional[Dict] = None,
):
    """
    Args:
        inputs: A list of tensors representing the input data.
        img_indices: A list of integers representing the indices of input tensors which are images.
        box_indices: (Optional) A list of integers representing the indices of input tensors which are bounding boxes. Defaults to an empty list.
        mask_indices: (Optional) A list of integers representing the indices of input tensors which are masks. Defaults to an empty list.
        probability: A float value between 0 and 1 representing the probability of flipping the inputs horizontally. Defaults to 0.5.
        img_flip_kwargs: (Optional) A dictionary of keyword arguments to be passed to the image flip function. Defaults to an empty dictionary.
        box_flip_kwargs: (Optional) A dictionary of keyword arguments to be passed to the bounding box flip function. Defaults to an empty dictionary.

    Returns:
        The modified list of input tensors after applying horizontal flips to the specified indices.

    """
    box_indices = box_indices or list()
    mask_indices = mask_indices or list()
    img_flip_kwargs = img_flip_kwargs or dict()
    box_flip_kwargs = box_flip_kwargs or dict()
    coin_flip = dali_fn.random.coin_flip(
        probability=probability,
        dtype=dali_types.DALIDataType.BOOL,
    )
    for index in img_indices + mask_indices:
        inputs[index] = dali_fn.flip(
            inputs[index],
            horizontal=dali_fn.cast(coin_flip, dtype=dali_types.DALIDataType.INT32),
            **img_flip_kwargs,
        )

    for index in box_indices:
        inputs[index] = dali_fn.bb_flip(
            inputs[index],
            horizontal=dali_fn.cast(coin_flip, dtype=dali_types.DALIDataType.INT32),
            **box_flip_kwargs,
        )
    return inputs

def resize_exact(
    inputs: List[TensorList],
    height: int,
    width: int,
    img_indices: List[int],
    mask_indices: Optional[List[int]] = None,
    probability: float = 1.0,
    resize_kwargs: Optional[Dict] = None,
):
    """
    Resizes the given inputs to a specific height and width.

    Args:
        inputs: A list of TensorLists representing the input data.
        height: An integer representing the desired height of the resized image.
        width: An integer representing the desired width of the resized image.
        img_indices: A list of integers representing the indices of the input images to be resized.
        mask_indices: (optional) A list of integers representing the indices of the input masks to be resized. Defaults to an empty list.
        probability: (optional) A float representing the probability of applying the resize operation. Defaults to 1.0.
        resize_kwargs: (optional) A dictionary containing optional resize arguments.

    Returns:
        The modified list of input TensorLists after applying the resize operation.
    """
    mask_indices = mask_indices or list()
    resize_kwargs = resize_kwargs or dict()

    if dali_fn.random.coin_flip(
        probability=probability,
        dtype=dali_types.DALIDataType.BOOL,
    ):
        for index in img_indices + mask_indices:
            inputs[index] = dali_fn.resize(
                inputs[index], resize_x=width, resize_y=height, **resize_kwargs
            )

    return inputs

def generate_random_data(num_samples):
    """
    Generates random data consisting of RGB images with bounding boxes, labels, and instance segmentation masks.

    Args:
        num_samples (int): The number of samples to generate.

    Returns:
        list: A list of tuples, each containing the following elements:
            - image (numpy.ndarray): The RGB image.
            - boxes (list): A list of normalized bounding boxes in the form [x, y, w, h].
            - labels (list): A list of labels corresponding to each bounding box.
            - masks (list): A list of instance segmentation masks for each bounding box.
            - num_boxes (int): The number of bounding boxes in the image.
    """
    data = []
    for _ in range(num_samples):
        # a) Create a random RGB image with random dimensions
        height, width = np.random.randint(100, 800, size=2)
        image = np.random.randint(255, size=(height, width, 3), dtype=np.uint8)

        # b) Generate normalized bounding boxes of form [x, y, w, h]
        num_boxes = np.random.randint(1, 10)  # Random number of bounding boxes
        boxes = []
        # c) Labels for each box
        labels = []
        # d) Instance segmentation masks for each box
        masks = []
        for _ in range(num_boxes):
            # Generate unnormalized bounding box coordinates
            x1, y1 = np.random.randint(0, width - 10), np.random.randint(0, height - 10)
            w, h = np.random.randint(
                10, min(width - x1, width // 2)
            ), np.random.randint(10, min(height - y1, height // 2))
            x2, y2 = x1 + w, y1 + h

            # Normalize the bounding box coordinates
            normalized_box = [x1 / width, y1 / height, w / width, h / height]

            boxes.append(normalized_box)
            # Random label for the bounding box from {1, 2, 3}
            labels.append(np.random.choice([1, 2, 3]))

            # Create instance mask for the bounding box
            mask = np.zeros((height, width), dtype=np.uint8)
            cv2.rectangle(mask, (x1, y1), (x2, y2), color=255, thickness=-1)
            masks.append(mask)
        masks = np.stack(masks)

        # Append the data as a tuple
        data.append((image, boxes, labels, masks, [num_boxes]))

    return data

class MyDataset(object):
    def __init__(
        self, input_db: tf.data.Dataset, transforms: Optional[List[DictConfig]] = None
    ):
        super(MyDataset, self).__init__()
        self.dataset = input_db
        self._transform_config = transforms or list()
        self._initialized = False
        self._transform: Optional[Callable] = None

    def _create_external_sources(
        self, db: tf.data.Dataset
    ) -> List[DataNode | Sequence[DataNode]]:
        """
        Args:
            db: A tf.data.Dataset object representing the input data to be processed.

        Returns:
            A list of DataNode objects or a list of sequences of DataNode objects.
             Each DataNode represents an external source to be used in the DALI pipeline.

        """
        sources = list()
        index = 0
        for key, value in self.dali_input_dict.items():
            sources.append(
                dali_fn.external_source(
                    source=self._ds2gen(db, index),
                    name=f"input_{key}",
                    device="gpu",
                    batch=False,
                    dtype=value,
                )
            )
            index += 1
        return sources

    @staticmethod
    def _ds2gen(ds, index):
        """
        Args:
            ds: A dataset containing the data.
            index: The index of the element to be accessed in each data item.

        Returns:
            A generator that yields the element at the specified index in each data item of the dataset.
        """
        for x in ds:
            yield x[index].numpy()

    @property
    def dali_input_dict(self) -> Dict[str, DALIDataType.DATA_TYPE]:
        """

        The `dali_input_dict` method returns a dictionary with predefined keys and their corresponding values of
        type `DALIDataType.DATA_TYPE`.

        Example:

            input_dict = obj.dali_input_dict

        Parameters:
            None

        Returns:
            Dictionary: Dictionary with predefined keys and their corresponding values.

        """
        input_dict = dict(
            image=DALIDataType.UINT8,
            bboxes=DALIDataType.FLOAT,
            labels=DALIDataType.INT32,
            mask=DALIDataType.UINT8,
            num_objects=DALIDataType.INT32,
        )
        return input_dict

    def construct_pipeline(self, db: tf.data.Dataset, batch_size: int):
        """

        Args:
            db: A TensorFlow Dataset object containing the input data.
            batch_size: An integer representing the desired batch size for the pipeline.

        Returns:
            A TensorFlow Pipeline object constructed using DALI with the given parameters.

        """

        @dali.pipeline_def(
            enable_conditionals=True,
            batch_size=batch_size,
            num_threads=4,
            device_id=0,
            seed=20,
        )
        def _build_pipeline():
            sources = self._create_external_sources(db)
            t_list = list()
            if not self._initialized:
                if self._transform_config is None:
                    self._transform = lambda x: x
                else:
                    for index in range(len(self._transform_config)):
                        t_list.append(
                            instantiate(
                                OmegaConf.create(self._transform_config[index]),
                                _convert_="object",
                                _partial_=True,
                            )
                        )
                    self._transform = compose(*t_list)
                self._initialized = True
            pipeline_output = self.transform_data(sources)
            return tuple(pipeline_output)

        return _build_pipeline()

    def transform_data(self, sources):
        """
        Transforms the provided data using the given sources.

        Args:
            sources (list): A list of data sources used for the transformation.

        Returns:
            list: The transformed data obtained from the pipeline execution.
        """
        pipeline_output = self._transform(sources)
        return pipeline_output

    @property
    def dali_output_dtypes(self):
        """
        Gets the output data types for the DALI pipeline.

        Returns:
            A tuple of data types (tf.uint8, tf.float32, tf.int32, tf.uint8, tf.int32).
        """
        return tf.uint8, tf.float32, tf.int32, tf.uint8, tf.int32

    def get_tf_dataset(self, batch_size: int):
        """
        Args:
            batch_size (int): The number of samples in each batch of the dataset.

        Returns:
            tf.data.Dataset: A TensorFlow Dataset object containing the data.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        pipeline = self.construct_pipeline(db=self.dataset, batch_size=batch_size)

        with tf.device("/gpu:0"):
            dataset = dali_tf.experimental.DALIDatasetWithInputs(
                pipeline=pipeline,
                output_dtypes=tuple([x for x in self.dali_output_dtypes]),
                device_id=0,
            )

        return dataset

num_samples = 5  # Specify the number of samples to generate
generated_data = generate_random_data(num_samples)

# Convert the generated data to a tf.data.Dataset
def generator():
    for sample in generated_data:
        yield sample

# Define the output types and shapes for the dataset

output_signature = (
    tf.TensorSpec(
        [None, None, 3], dtype=tf.uint8
    ),  # Images can have different sizes, hence [None, None, 3]
    tf.TensorSpec(
        [None, 4], dtype=tf.float32
    ),  # Bounding boxes are of the shape [None, 4]
    tf.TensorSpec([None], dtype=tf.int32),  # Labels are a 1D array of integers
    tf.TensorSpec(
        [None, None, None], dtype=tf.uint8
    ),  # Masks can have different sizes, hence [None, None, None]
    tf.TensorSpec([None], dtype=tf.int32),
)

# Create the tf.data.Dataset
dataset = tf.data.Dataset.from_generator(generator, output_signature=output_signature)

@hydra.main(version_base="1.3", config_path="configs_sample", config_name="config")
def my_app(cfg: DictConfig) -> None:
    if cfg.transforms is None:
        transform_config = None
    else:
        transform_config = list(cfg.transforms.values())

    tf_db = MyDataset(dataset, transforms=transform_config).get_tf_dataset(1)
    for data_sample in tf_db:
        print(data_sample)

if __name__ == "__main__":
    my_app()
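
Composition order is easy to get backwards when transforms are chained with `reduce`: folding with `f(g(x))` from an identity seed applies the last function first, while folding with `g(f(x))` applies the functions in the order they are listed. The following is a minimal sketch with plain Python functions; the names `compose_reversed`, `compose_in_order`, `step_a`, and `step_b` are hypothetical and are not part of the script above.

```python
from functools import reduce


def compose_reversed(*functions):
    # Folding with f(g(x)) from an identity seed runs the LAST function first.
    return reduce(lambda f, g: lambda x: f(g(x)), functions, lambda x: x)


def compose_in_order(*functions):
    # Folding with g(f(x)) runs the functions in the order they are listed.
    return reduce(lambda f, g: lambda x: g(f(x)), functions, lambda x: x)


# Hypothetical transforms that append their name to a trace list.
def step_a(trace):
    return trace + ["a"]


def step_b(trace):
    return trace + ["b"]


reversed_trace = compose_reversed(step_a, step_b)([])
ordered_trace = compose_in_order(step_a, step_b)([])
print(reversed_trace)  # ['b', 'a']
print(ordered_trace)   # ['a', 'b']
```

For augmentations whose results depend on order (for example, resizing before or after a crop), it may be worth checking which of the two forms a `compose` helper actually implements.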
ujjwalnur commented 2 days ago

Hi @JanuszL,

When I run the above snippet, here is the error I get:

2024-07-01 11:51:50.352523: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/backend.py:47: Warning: DALI support for Python 3.12 is experimental and some functionalities may not work.
  deprecation_warning(
2024-07-01 11:51:52.252241: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.252624: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.252957: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.253281: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.292281: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.292713: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.293050: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.293373: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.293692: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.294003: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.294313: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.294633: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.624285: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.624685: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.625005: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.625310: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.625612: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.625903: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.626192: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.626479: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.626764: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.627051: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.627335: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.627627: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.649258: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.649584: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.649892: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.650193: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.650485: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.650773: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.651057: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.651339: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.651640: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.651890: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22279 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:01:00.0, compute capability: 8.9
2024-07-01 11:51:52.652391: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.652634: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 22279 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:41:00.0, compute capability: 8.9
2024-07-01 11:51:52.653008: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.653248: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 22279 MB memory:  -> device: 2, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:81:00.0, compute capability: 8.9
2024-07-01 11:51:52.653669: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-07-01 11:51:52.653909: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 22279 MB memory:  -> device: 3, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:c1:00.0, compute capability: 8.9
[2024-07-01 11:51:57,411][root][WARNING] - AutoGraph could not transform <function _resolve_container_value at 0x732968b7e3e0> and will run it as-is.
Cause:
To silence this warning, decorate the function with @nvidia.dali.pipeline.do_not_convert
[2024-07-01 11:51:59,171][root][WARNING] - AutoGraph could not transform <function _gcd_import at 0x732ca79280e0> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x732ca79280e0>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call in the do_not_convert decorator. Original error: could not get source code
To silence this warning, decorate the function with @nvidia.dali.pipeline.do_not_convert
Error executing job with overrides: []
Traceback (most recent call last):
  File "/media/homes/ujjwal/research-2024/localize/reproduce.py", line 386, in <module>
    my_app()
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
           ^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
            ^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
        ^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/homes/ujjwal/research-2024/localize/reproduce.py", line 380, in my_app
    tf_db = MyDataset(dataset, transforms=transform_config).get_tf_dataset(1)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/homes/ujjwal/research-2024/localize/reproduce.py", line 331, in get_tf_dataset
    pipeline = self.construct_pipeline(db=self.dataset, batch_size=batch_size)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/homes/ujjwal/research-2024/localize/reproduce.py", line 295, in construct_pipeline
    return _build_pipeline()
           ^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/pipeline.py", line 1973, in create_pipeline
    _generate_graph(pipe, pipe_func, args, fn_kwargs)
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/pipeline.py", line 1824, in _generate_graph
    pipe_outputs = func(*fn_args, **fn_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 697, in wrapper
    raise e.ag_error_metadata.to_exception(e)
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 694, in wrapper
    return converted_call(f, args, kwargs, options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 453, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqnt0cbbj.py", line 54, in autograph___build_pipeline
    pipeline_output = ag__.converted_call(ag__.ld(self).transform_data, (ag__.ld(sources),), None, fscope)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 455, in converted_call
    result = converted_f(*effective_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file244q_hv9.py", line 19, in autograph__transform_data
    pipeline_output = ag__.converted_call(ag__.ld(self)._transform, (ag__.ld(sources),), None, fscope)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 353, in converted_call
    return _call_unconverted(f, args, kwargs, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 470, in _call_unconverted
    return f(*args)
           ^^^^^^^^
  File "/tmp/__autograph_generated_filegubqb6d1.py", line 26, in <lambda>
    retval__1 = ag__.autograph_artifact(lambda x: ag__.converted_call(ag__.ld(f), (ag__.converted_call(ag__.ld(g), (ag__.ld(x),), None, fscope_1),), None, fscope_1))
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 353, in converted_call
    return _call_unconverted(f, args, kwargs, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 470, in _call_unconverted
    return f(*args)
           ^^^^^^^^
  File "/tmp/__autograph_generated_filegubqb6d1.py", line 26, in <lambda>
    retval__1 = ag__.autograph_artifact(lambda x: ag__.converted_call(ag__.ld(f), (ag__.converted_call(ag__.ld(g), (ag__.ld(x),), None, fscope_1),), None, fscope_1))
                                                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 365, in converted_call
    return converted_call(
           ^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 453, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileg01l5ise.py", line 26, in autograph__horizontal_flip
    ag__.for_stmt(ag__.ld(img_indices) + ag__.ld(mask_indices), None, loop_body, get_state, set_state, (), {'iterate_names': 'index'})
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/operators/control_flow.py", line 108, in for_stmt
    _py_for_stmt(iter_, extra_test, body, None, None)
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/operators/control_flow.py", line 124, in _py_for_stmt
    body(target)
  File "/tmp/__autograph_generated_fileg01l5ise.py", line 24, in loop_body
    ag__.ld(inputs)[ag__.ld(index)] = ag__.converted_call(ag__.ld(dali_fn).flip, (ag__.ld(inputs)[ag__.ld(index)],), dict(horizontal=ag__.converted_call(ag__.ld(dali_fn).cast, (ag__.ld(coin_flip),), dict(dtype=ag__.ld(dali_types).DALIDataType.INT32), fscope), **ag__.ld(img_flip_kwargs)), fscope)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 387, in converted_call
    return _call_unconverted(f, args, kwargs, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_autograph/impl/api.py", line 469, in _call_unconverted
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/fn/__init__.py", line 99, in fn_wrapper
    return op_wrapper(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/fn/__init__.py", line 80, in op_wrapper
    return op_class(**init_args)(*inputs, **call_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/ops/__init__.py", line 628, in __call__
    _OperatorInstance(input_set, arg_inputs, args, self._init_args, self)
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/ops/__init__.py", line 379, in __init__
    inputs, arg_inputs = _conditionals.apply_conditional_split_to_args(inputs, arg_inputs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 518, in apply_conditional_split_to_args
    inputs = apply_conditional_split_to_branch_outputs(inputs, False)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 512, in apply_conditional_split_to_branch_outputs
    return _map_structure(apply_split, branch_outputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 58, in _map_structure
    return tree.map_structure(func, *structures, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/tree/__init__.py", line 428, in map_structure
    [func(*args) for args in zip(*map(flatten, structures))])
     ^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 503, in apply_split
    return apply_conditional_split(atom)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 481, in apply_conditional_split
    return this_condition_stack().preprocess_input(input)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 298, in preprocess_input
    stack_level = self._find_closest(data_node)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 240, in _find_closest
    raise ValueError(f"{data_node} was not produced within this trace.")
ValueError: in user code:

    File "/media/homes/ujjwal/research-2024/localize/reproduce.py", line 292, in _build_pipeline  *
        pipeline_output = self.transform_data(sources)
    File "/media/homes/ujjwal/research-2024/localize/reproduce.py", line 307, in transform_data  *
        pipeline_output = self._transform(sources)
    File "/media/homes/ujjwal/research-2024/localize/core/augmentations/ops.py", line 28, in horizontal_flip  *
        inputs[index] = dali_fn.flip(
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/fn/__init__.py", line 99, in fn_wrapper  **
        return op_wrapper(*inputs, **kwargs)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/fn/__init__.py", line 80, in op_wrapper
        return op_class(**init_args)(*inputs, **call_args)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/ops/__init__.py", line 628, in __call__
        _OperatorInstance(input_set, arg_inputs, args, self._init_args, self)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/ops/__init__.py", line 379, in __init__
        inputs, arg_inputs = _conditionals.apply_conditional_split_to_args(inputs, arg_inputs)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 518, in apply_conditional_split_to_args
        inputs = apply_conditional_split_to_branch_outputs(inputs, False)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 512, in apply_conditional_split_to_branch_outputs
        return _map_structure(apply_split, branch_outputs)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 58, in _map_structure
        return tree.map_structure(func, *structures, **kwargs)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/tree/__init__.py", line 428, in map_structure
        [func(*args) for args in zip(*map(flatten, structures))])
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 503, in apply_split
        return apply_conditional_split(atom)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 481, in apply_conditional_split
        return this_condition_stack().preprocess_input(input)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 298, in preprocess_input
        stack_level = self._find_closest(data_node)
    File "/home/ujjwal/anaconda3/envs/tf2.16/lib/python3.12/site-packages/nvidia/dali/_conditionals.py", line 240, in _find_closest
        raise ValueError(f"{data_node} was not produced within this trace.")

    ValueError: DataNode(name="__Resize_6", device="gpu") was not produced within this trace.

JanuszL commented 2 days ago

@ujjwalnur - can you upload a zip file with all the configs and the directory structure or something that will generate it? Copying them from the screenshots is not convenient.

ujjwalnur commented 2 days ago

@JanuszL Here you go.

reproduce.zip

ujjwalnur commented 2 days ago

As additional context: if you comment out any one of the augmentations in config.yaml, the code runs.
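
For readers without the zip file: the `<lambda>` frames in the traceback (`lambda x: f(g(x))`) suggest the augmentations are chained through a generic compose helper. The sketch below illustrates that chaining pattern in plain Python only. It is not the code from the repro and not a diagnosis; `compose`, `resize_exact`, and `horizontal_flip` here are hypothetical stand-ins that simply record the order in which they run.

```python
from functools import partial, reduce


def compose(*functions):
    """Chain callables right to left: compose(f, g)(x) == f(g(x))."""
    return reduce(lambda f, g: lambda x: f(g(x)), functions)


# Hypothetical stand-ins for the configured augmentations. They do no image
# processing; each one only appends its name to the list it receives.
def resize_exact(inputs, size):
    return inputs + [f"resize_exact({size})"]


def horizontal_flip(inputs, probability):
    return inputs + [f"horizontal_flip({probability})"]


# Partial functions, similar in spirit to what a config-driven setup builds.
augmentations = [
    partial(horizontal_flip, probability=0.5),
    partial(resize_exact, size=512),
]

transform = compose(*augmentations)
applied = transform([])
print(applied)  # ['resize_exact(512)', 'horizontal_flip(0.5)']
```

With a single augmentation, `compose` returns that function unchanged and no wrapping lambda is created, so the nested `<lambda>` frames only appear once two or more augmentations are chained.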

ujjwalnur commented 1 day ago

Were you able to reproduce the issue?

JanuszL commented 1 day ago

Yes, the repro works fine. Thank you for that. I will look into it later and get back to you soon.

ujjwalnur commented 23 hours ago

Hi,

Any insights into this?

JanuszL commented 23 hours ago

@klecki can you look into it?