mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Inconsistent behavior of `UInt` constructor for PyTorch and numpy #111

Closed mworchel closed 1 year ago

mworchel commented 1 year ago

Hi,

I'm trying to convert integer PyTorch tensors to UInt arrays. Here is a minimal reproducer:

import drjit as dr
import torch

m = torch.tensor([1, 2, 3], dtype=torch.int32, device='cuda:0')

# Fails
dr.cuda.ad.UInt(m)

# Succeeds
dr.cuda.ad.UInt(m.cpu().numpy())

The line that fails produces the error message:

Exception: Incompatible type!

which I understand, because my tensor contains signed integers. However, surprisingly, for numpy the conversion works, both for int32 and int64.

Ideally, I want to avoid copies to the CPU so I'm looking for a way to stay on the GPU without using the numpy detour. Unfortunately, PyTorch does not support uint32 or uint64 tensors (https://github.com/pytorch/pytorch/issues/58734). Is it possible to get the same conversion behavior for PyTorch as for numpy, i.e., being able to cast from int32 to uint32?

Thanks!

njroussel commented 1 year ago

Hi @mworchel

Maybe I'm misunderstanding, but what is stopping you from doing the following?

import drjit as dr
import torch

m = torch.tensor([1, 2, 3], dtype=torch.int32, device='cuda:0')

signed = dr.cuda.ad.Int(m)
unsigned = dr.cuda.ad.UInt(signed)

mworchel commented 1 year ago

Hi @njroussel,

True, that's a possible way to get the conversion done!

I guess I was just puzzled about why the conversion behavior differs between numpy and PyTorch. I've had a look at the source code, and it seems that the following conversion makes the numpy casting more relaxed:

if o.dtype != self.Type.NumPy:
    o = o.astype(self.Type.NumPy)
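
As a rough illustration of why that `astype` call makes the numpy path permissive, here is a plain NumPy sketch (not Dr.Jit code; the array values are made up for the example). `astype` casts with `casting='unsafe'` by default, so a signed array is converted to an unsigned dtype without complaint, even though NumPy does not consider that cast "safe":

```python
import numpy as np

# Hypothetical input, with one negative value to make the wrap-around visible.
signed = np.array([1, 2, 3, -1], dtype=np.int32)

# astype() defaults to casting='unsafe', so the signed -> unsigned
# conversion goes through instead of raising.
unsigned = signed.astype(np.uint32)

# The same holds for int64 input, which is narrowed to 32 bits.
from_int64 = np.array([1, 2, 3], dtype=np.int64).astype(np.uint32)

# Under the stricter 'safe' rule, NumPy itself reports that this cast
# is not value-preserving.
is_safe = np.can_cast(np.int32, np.uint32, casting="safe")

print(unsigned.dtype, unsigned)
print(from_int64.dtype, from_int64)
print(is_safe)
```

Note that negative inputs wrap modulo 2**32 under this kind of cast, so the relaxed behavior is only harmless when the values are known to be non-negative.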

Anyway, this is probably a non-issue, so let's close this thread. Thanks! :)

njroussel commented 1 year ago

These casts are a bit of a pain, and I don't think this is the first time I've come across an oddity like mi.UInt(mi.Int(var)).

Great!