mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

drjit.gather does not support PyTrees defined with @dataclass #303

Closed by zihay 1 month ago

zihay commented 1 month ago

Hi,

I'm using Dr.Jit in my project together with PyTrees, as described in the documentation. I expected drjit.gather to work with data classes defined using the @dataclass decorator, but it raises an error instead. Code that doesn't work:

from dataclasses import dataclass
import drjit as dr
from drjit.auto import Array2, Int

def test_bbox():
    @dataclass
    class BoundingBox:
        p_min: Array2
        p_max: Array2

    boxes = dr.zeros(BoundingBox, 8)
    box = dr.gather(BoundingBox, boxes, Int(1))
    print(box.p_min)
    print(box.p_max)

test_bbox()

Error message:

TypeError: drjit.gather(<__main__.BoundingBox>): unsupported dtype!

Code that works:

import drjit as dr
from drjit.auto import Array2, Int

def test_bbox():
    class BoundingBox:
        # DRJIT_STRUCT maps each field name to its type so Dr.Jit can traverse the class
        DRJIT_STRUCT = {'p_min': Array2, 'p_max': Array2}
        p_min: Array2
        p_max: Array2

    boxes = dr.zeros(BoundingBox, 8)
    box = dr.gather(BoundingBox, boxes, Int(1))
    print(box.p_min)
    print(box.p_max)

test_bbox()

In this version, I define the BoundingBox class with a DRJIT_STRUCT class attribute, and drjit.gather works as expected. Based on the documentation, I expected data classes to be supported by operations like drjit.gather in the same way as custom classes with DRJIT_STRUCT. Is this a bug, or is an additional step required to make data classes compatible with drjit.gather? Thank you for your assistance!
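To see why the two versions behave differently, it helps to picture what a gather over a PyTree has to do. It must discover the class's fields, gather each field at the requested index, and rebuild an instance of the same class. The report suggests that, before the fix, gather found that field list through DRJIT_STRUCT but did not recognize a dataclass's fields. The sketch below is a pure-Python toy model of that idea and not Dr.Jit's implementation or the change made in the fix. The names toy_gather and struct_fields are hypothetical, and plain Python lists stand in for Dr.Jit arrays.

```python
from dataclasses import dataclass, fields, is_dataclass

# Toy model only: plain lists stand in for Dr.Jit arrays, and
# toy_gather / struct_fields are hypothetical helpers, not Dr.Jit APIs.

def struct_fields(tp):
    """Return {field_name: field_type} for a DRJIT_STRUCT class or a dataclass."""
    if hasattr(tp, "DRJIT_STRUCT"):
        return dict(tp.DRJIT_STRUCT)
    if is_dataclass(tp):
        return {f.name: f.type for f in fields(tp)}
    return None  # not struct-like: treat as a leaf array

def toy_gather(tp, source, index):
    """Pick element `index` from every leaf of `source`, preserving its structure."""
    layout = struct_fields(tp)
    if layout is None:
        return source[index]  # leaf: index into the list
    # Recurse into each field, then rebuild an instance of the same class.
    values = {name: toy_gather(ftp, getattr(source, name), index)
              for name, ftp in layout.items()}
    if is_dataclass(tp):
        return tp(**values)  # dataclasses are built through their generated __init__
    result = tp()  # DRJIT_STRUCT-style class: default-construct, then assign fields
    for name, value in values.items():
        setattr(result, name, value)
    return result

@dataclass
class BoundingBox:
    p_min: list
    p_max: list

class BoundingBoxStruct:
    DRJIT_STRUCT = {"p_min": list, "p_max": list}

boxes = BoundingBox(p_min=[0, 1, 2], p_max=[10, 11, 12])
box = toy_gather(BoundingBox, boxes, 1)
print(box.p_min, box.p_max)  # 1 11
```

In this model, the only thing that separates the two kinds of class is how struct_fields discovers their layout. That is why a gather that handles DRJIT_STRUCT classes does not automatically handle dataclasses.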

rtabbara commented 1 month ago

Hi @zihay ,

Thanks for reporting this. Support for data classes in drjit.gather has now been added to master (PR #305).