mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Tensor support for gather() #151

Closed Linyou closed 1 year ago

Linyou commented 1 year ago

Hi! Thanks for this amazing work.

I am currently attempting to integrate drjit with PyTorch, but have encountered an issue. Here is a snippet of code that did not function properly with drjit:

def get_grid_feat_mini(xyzs: Array3f, table: Array2f, res: int):

    pos = xyzs * (res-1) + 0.5 
    pos_grid = dr.floor(pos)
    pos -= pos_grid

    a = pos[0]
    b = pos[1]
    c = pos[2]

    pos_x = pos_grid[0]
    pos_y = pos_grid[1]
    pos_z = pos_grid[2]

    index = (((pos_x) + ((pos_y) * res) + ((pos_z) * res * res)))

    w = (1.0 - a) * (1.0 - b) * (1.0 - c)

    out_feat = w * dr.gather(Array2f, table, index)

    return out_feat

@dr.wrap_ad(source='torch', target='drjit')
def run_test(xyz, hash_table):
    return get_grid_feat_mini(xyz, hash_table, 128)

device = torch.device('cuda:0')
xyz = torch.ones(1000000, 3, requires_grad=True, device=device)
table = torch.ones(2097281, 2, requires_grad=True, device=device)

grid_feat = run_test(xyz, table)

Here is the error info:

  File "/home/loyot/workspace/code/git_worktree/drjit/modules/drjit_test.py", line 106, in get_grid_index_mini
    out_feat = w * dr.gather(Array2f, table, index)
  File "/home/loyot/anaconda3/envs/torch2/lib/python3.9/site-packages/drjit/router.py", line 713, in gather
    raise Exception("gather(): Tensor type not supported! Should work "
drjit.Exception: gather(): Tensor type not supported! Should work with the underlying array instead. (e.g. tensor.array)

However, the code runs perfectly on Dr.Jit arrays. Is there a way to make it work on a tensor?

njroussel commented 1 year ago

Hi @Linyou

As the error message suggests, you can access the underlying array of a tensor through my_tensor_variable.array. You can then use the gather operation on this array. Also, TensorXf types support the slicing operator my_tensor_variable[idx_d0, idx_d1, ... idx_dn], which might be easier to use in certain situations.
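
To illustrate what working on the underlying array means, here is a minimal pure-Python sketch. It is not Dr.Jit code. It assumes the tensor's storage is a flat, row-major buffer, so element (row, col) of a tensor with shape (rows, cols) sits at flat position row * cols + col. The helper names flat_index and gather_rows are hypothetical stand-ins, not part of any library.

```python
# Pure-Python stand-in for a tensor of shape (rows, cols) and its flat storage.
# Assumption: row-major layout, as stated in the lead-in.

def flat_index(row, col, cols):
    """Flat position of element (row, col) in a row-major buffer."""
    return row * cols + col


def gather_rows(flat, rows_wanted, cols):
    """Collect whole rows from a flat buffer, one flat read per element."""
    return [[flat[flat_index(r, c, cols)] for c in range(cols)]
            for r in rows_wanted]


# A 4 x 2 "table" stored flat: row r holds (10 * r, 10 * r + 1).
cols = 2
flat_table = [10 * r + c for r in range(4) for c in range(cols)]

picked = gather_rows(flat_table, [3, 0], cols)
print(picked)  # [[30, 31], [0, 1]]
```

The point of the sketch is that a gather on flat storage takes element positions, not row numbers, so a two-column table needs two flat reads per row.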

Linyou commented 1 year ago

Thank you for your response, @njroussel. I have updated the code to directly slice through table as follows:


out_feat = w * table[index]

However, I encountered another error:

  File "/home/loyot/workspace/code/git_worktree/drjit/modules/drjit_test.py", line 112, in run_index
    return get_grid_index_mini(xyz, hash_table, 128)
  File "/home/loyot/workspace/code/git_worktree/drjit/modules/drjit_test.py", line 106, in get_grid_index_mini
    out_feat = w * table[index]
  File "/home/loyot/anaconda3/envs/torch2/lib/python3.9/site-packages/drjit/detail.py", line 701, in tensor_getitem
    shape, index = slice_tensor(tensor.shape, slice_arg, tensor_t.Index)
  File "/home/loyot/anaconda3/envs/torch2/lib/python3.9/site-packages/drjit/detail.py", line 643, in slice_tensor
    raise TypeError("slice_tensor(): type '%s' cannot be used to index into a tensor!",
TypeError: ("slice_tensor(): type '%s' cannot be used to index into a tensor!", 'TensorXf')

It seems that the index obtained from pos is not a valid index and is being seen as type %s. Does this mean that I need to convert index somehow?

njroussel commented 1 year ago

The error message has a formatting issue; I'll fix that right away.

The slicing operator can only be used with pure Python integer values. Currently, you're trying to use a TensorXf as the indices. The type hints in your get_grid_feat_mini are wrong/confusing: both the inputs and outputs are of type TensorXf. To index into a TensorXf with a Dr.Jit type, you should use something like dr.gather(mi.Float, my_tensor.array, drjit_uint_type_index). In your example, the index variable is a tensor, so you can most likely do dr.gather(mi.Float, my_tensor.array, index.array).
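
As a worked check of the index arithmetic in get_grid_feat_mini, the sketch below redoes it for a single point in plain Python rather than Dr.Jit. It rests on two assumptions: the gather index has to be an integer (the floored position is otherwise a float), and a two-channel row is read from flat storage at positions 2 * index and 2 * index + 1. cell_index_and_weight is a hypothetical helper name.

```python
import math


def cell_index_and_weight(xyz, res):
    """Mirror the index/weight math of get_grid_feat_mini for one point."""
    pos = [v * (res - 1) + 0.5 for v in xyz]
    cell = [math.floor(p) for p in pos]          # integer cell coordinates
    frac = [p - c for p, c in zip(pos, cell)]    # fractional offsets a, b, c
    # Linear index into the table; kept as an int so it can address storage.
    index = cell[0] + cell[1] * res + cell[2] * res * res
    weight = (1.0 - frac[0]) * (1.0 - frac[1]) * (1.0 - frac[2])
    return index, weight


res = 128
index, weight = cell_index_and_weight((1.0, 1.0, 1.0), res)
print(index, weight)  # 2097151 0.125

# Two channels per row, stored flat: row i occupies slots 2*i and 2*i + 1.
channels = 2
flat_slots = [channels * index + c for c in range(channels)]
print(flat_slots)  # [4194302, 4194303]
```

For the all-ones input used in the post, the row index is 2097151, which is within the 2097281 rows of the table.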