NVIDIA / warp

A Python framework for high performance GPU simulation and graphics
https://nvidia.github.io/warp/

Chained comparison a < b < c is not consistent with Python semantics #162

Closed: HaoyangShi-SLC closed this issue 5 months ago

HaoyangShi-SLC commented 6 months ago

Warp kernels appear to evaluate a < b < c as (a < b) < c instead of (a < b) and (b < c), which is what Python does. I ran a simple test case to reproduce it:

import warp as wp

wp.init()

# Boolean output array, one entry per thread.
out = wp.zeros(shape=10, dtype=wp.bool)

@wp.kernel
def test(output: wp.array(dtype=wp.bool)):
    a = wp.tid()
    # Chained comparison on ints; in Python this is False for every integer a.
    output[a] = 0 < a < 1

@wp.kernel
def test2(output: wp.array(dtype=wp.bool)):
    a = wp.tid()
    # The same test written out explicitly with `and`.
    output[a] = 0 < a and a < 1

@wp.kernel
def test3(output: wp.array(dtype=wp.bool)):
    a = float(wp.tid())
    # Chained comparison on floats.
    output[wp.tid()] = 0.0 < a < 1.0

wp.launch(test, dim=10, inputs=[out])
print(out.numpy())
wp.launch(test2, dim=10, inputs=[out])
print(out.numpy())
wp.launch(test3, dim=10, inputs=[out])
print(out.numpy())

ref = []
for i in range(10):
    ref.append(0 < i < 1)
print(ref)

Running it prints:

> python .\test.py
Warp 1.0.0-beta.2 initialized:
   CUDA Toolkit: 12.1, Driver: 12.2
   Devices:
     "cpu"    | Intel64 Family 6 Model 154 Stepping 3, GenuineIntel
     "cuda:0" | NVIDIA GeForce RTX 3070 Ti Laptop GPU (sm_86)
   Kernel cache: ...
Module __main__ load on device 'cuda:0' took 1.80 ms
[ True False False False False False False False False False]
[False False False False False False False False False False]
[ True False False False False False False False False False]
[False, False, False, False, False, False, False, False, False, False]
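The first and third rows match the left-associative reading (0 < a) < 1 exactly, provided the intermediate bool is treated as 0 or 1 when it is compared with a number. Python itself does this because bool is a subclass of int, and C++ applies the same promotion to bool operands. The plain-Python check below reproduces the observed rows that way. That Warp's generated code promoted the bool like this is an inference from the output, not something stated in this thread.

```python
# Observed output of the int kernel (test) and the float kernel (test3) above.
observed = [True] + [False] * 9

# Left-associative reading (0 < a) < 1, with the intermediate bool used as 0/1.
int_left = [int(0 < i) < 1 for i in range(10)]
float_left = [float(0.0 < float(i)) < 1.0 for i in range(10)]

# Python's chained semantics: (0 < a) and (a < 1), never true for integer a.
python_chained = [0 < i < 1 for i in range(10)]

print(int_left == observed)        # True
print(float_left == observed)      # True
print(python_chained == observed)  # False
```

The only element that differs is index 0: (0 < 0) is False, and False < 1 is True.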

I'm surprised that when a, b, and c are floats, this program still compiles and runs, and produces incorrect results. That makes no sense to me: Warp currently has no implicit type conversion, so if a < b produces a bool, the second comparison should fail with a type error anyway. This behavior isn't described in the Warp documentation and may confuse others.

c0d1f1ed commented 6 months ago

This was recently addressed: https://github.com/NVIDIA/warp/commit/e2fc69378cb609675d5b414d9cd52fc2398baede. Please try upgrading to the latest release.