NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

How to create a tensor in a custom python function within define_graph #5546


rachelglenn commented 5 days ago

Describe the question.

How do I create new torch tensors inside a custom Python function and make sure they end up on the correct device? I would like to do things like taking the square of a tensor. I found this example: https://docs.nvidia.com/deeplearning/dali/archives/dali_1_18_0/user-guide/docs/examples/custom_operations/python_operator.html

def edit_images(image1, image2):
    assert image1.shape == image2.shape
    for i in range(c):
        h, w, c = image1.shape
        perturbation = torch.rand(h, w) 
        new_image1 = torch.zeros(h,w,c)
        new_image2 = torch.zeros(h,w,c)
        new_image1[:, :, i] = image1[:, :, i] * torch.square(perturbation)
        new_image2[:, :, i] = image2[:, :,i]  * torch.square(perturbation)
    return new_image1, new_image2


mzient commented 5 days ago

Hello @rachelglenn, I strongly advise against using PythonFunction for functionality with good native support. You can get tensors filled with random values with functions from dali.fn.random. Elementwise squaring can be achieved simply by multiplying a tensor by itself, like:

    # passing image1 as the argument will cause the function to return an array shaped like the image
    perturbation = fn.random.uniform(image1, range=[0, 1])  # this already includes the channel dimension
    pert_squared = perturbation * perturbation
    new_image1 = image1 * pert_squared
    new_image2 = image2 * pert_squared
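The elementwise identity behind this trick can be checked outside DALI. Below is a minimal sketch using NumPy as a stand-in for the pipeline's tensors (the shapes and seed are illustrative assumptions; DALI data nodes support the same arithmetic operators):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the pipeline's image and random perturbation,
# in (h, w, c) layout to match the original example.
image = rng.random((4, 4, 3)).astype(np.float32)
perturbation = rng.random((4, 4, 3)).astype(np.float32)

# Elementwise squaring via multiplication, as in the snippet above
pert_squared = perturbation * perturbation

# Multiplying a tensor by itself matches an explicit elementwise square
assert np.allclose(pert_squared, np.square(perturbation))

new_image = image * pert_squared
assert new_image.shape == image.shape
```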

BTW - it seems like the code is incorrect (swapped lines?):

    for i in range(c):
        h, w, c = image1.shape # c defined here, but loop over range(c) above
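For reference, a corrected sketch of that loop: the shape is unpacked before iterating over `c`, and the output tensors are allocated once rather than re-created on every channel (assuming torch and the original per-channel-perturbation intent):

```python
import torch

def edit_images(image1, image2):
    assert image1.shape == image2.shape
    h, w, c = image1.shape             # unpack shape before looping over c
    new_image1 = torch.zeros(h, w, c)  # allocate outputs once, not per channel
    new_image2 = torch.zeros(h, w, c)
    for i in range(c):
        perturbation = torch.rand(h, w)
        new_image1[:, :, i] = image1[:, :, i] * torch.square(perturbation)
        new_image2[:, :, i] = image2[:, :, i] * torch.square(perturbation)
    return new_image1, new_image2
```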
JanuszL commented 5 days ago

Still, if you prefer to use torch_python_function, just wrap tensor creation in torch.cuda.device inside the callable so the new tensors land on the right GPU.
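A minimal sketch of that tip. The function name, the `device_id` parameter, and the CPU fallback are illustrative assumptions; inside a real torch_python_function callback you would use the device the DALI pipeline runs on:

```python
import torch

def make_perturbation(h, w, device_id=0):
    # torch.cuda.device sets the current CUDA device, so tensors
    # created with device="cuda" inside the block land on that GPU.
    if torch.cuda.is_available():
        with torch.cuda.device(device_id):
            return torch.rand(h, w, device="cuda")
    # CPU fallback so the sketch also runs without a GPU
    return torch.rand(h, w)
```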