Closed (xuan-li closed this issue 2 months ago)
To reproduce:
adj.py:
import warp as wp


@wp.func
def overload_fn(x: float, y: float):
    return x * 3.0 + y / 3.0, y**2.5


@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    wp.adjoint[y] += y * adj_ret1 * 3.0


if __name__ == "__main__":
    wp.init()

    @wp.kernel
    def overload_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        overload_fn(x[tid], y[tid])

    x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
    y = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)
    wp.launch(overload_kernel, inputs=[x, y], dim=x.shape[0])
test_adj.py:
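# The two imports below are missing from the listing as posted; they are assumed
# here because the script uses wp and overload_fn from adj.py.
import warp as wp
from adj import overload_fn

wp.init()


@wp.kernel
def overload_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    overload_fn(x[tid], y[tid])


x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
y = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)
wp.launch(overload_kernel, inputs=[x, y], dim=x.shape[0])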
Running adj.py works fine, but running test_adj.py fails with the following error:
Warp NVRTC compilation error 6: NVRTC_ERROR_COMPILATION (/buildAgent/work/a9ae500d09a78409/warp/native/warp.cu:1674)
default_program(122): error: identifier "adj_overload_fn" is undefined
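For reference, here is what the custom gradient in adj.py is written to accumulate, spelled out in plain Python without Warp. This is only a sketch: the helper names `forward_ref` and `custom_adjoint_ref` are hypothetical and not part of Warp, and the adjoint seeds `adj_ret0` and `adj_ret1` are passed in explicitly. Note that the custom gradient deliberately differs from the analytic derivative of `overload_fn`, which makes it easy to tell whether the custom version is the one being used.

```python
# Plain-Python mirror of overload_fn and overload_fn_grad from adj.py.
# The helper names are hypothetical; nothing here calls into Warp.


def forward_ref(x: float, y: float):
    # Same expressions as the body of overload_fn.
    return x * 3.0 + y / 3.0, y ** 2.5


def custom_adjoint_ref(x: float, y: float, adj_ret0: float, adj_ret1: float):
    # Same expressions that overload_fn_grad adds to wp.adjoint[x] and wp.adjoint[y].
    adj_x = x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    adj_y = y * adj_ret1 * 3.0
    return adj_x, adj_y


if __name__ == "__main__":
    # Same input values as the arrays in the reproduction scripts,
    # with both output adjoints seeded to 1.0 (an assumption for illustration).
    for x, y in zip([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]):
        print(forward_ref(x, y), custom_adjoint_ref(x, y, 1.0, 1.0))
```

Once the kernel compiles, these reference values could be compared against the gradients Warp reports for a kernel that actually uses the outputs of `overload_fn`.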
Thanks for reporting this! This should also be fixed in the next release.
This should now be fixed.