NVIDIA / warp

A Python framework for high performance GPU simulation and graphics
https://nvidia.github.io/warp/

Cannot store customized adjoint function in a separate file #187

Closed: xuan-li closed this issue 2 months ago

xuan-li commented 3 months ago

To reproduce:

adj.py:

import warp as wp

@wp.func
def overload_fn(x: float, y: float):
    return x * 3.0 + y / 3.0, y**2.5

@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    wp.adjoint[y] += y * adj_ret1 * 3.0

if __name__ == "__main__":
    wp.init()
    @wp.kernel
    def overload_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        overload_fn(x[tid], y[tid])

    x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
    y = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)
    wp.launch(overload_kernel, inputs=[x, y], dim=x.shape[0])

test_adj.py:

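import warp as wp
from adj import overload_fn  # overload_fn and its custom adjoint are defined in adj.py

wp.init()
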
@wp.kernel
def overload_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    overload_fn(x[tid], y[tid])

x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
y = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)
wp.launch(overload_kernel, inputs=[x, y], dim=x.shape[0])

Running adj.py directly works fine, but running test_adj.py, which imports overload_fn from adj.py, fails with the following error:

Warp NVRTC compilation error 6: NVRTC_ERROR_COMPILATION (/buildAgent/work/a9ae500d09a78409/warp/native/warp.cu:1674)
default_program(122): error: identifier "adj_overload_fn" is undefined
shi-eric commented 3 months ago

Thanks for reporting this! This should also be fixed in the next release.

shi-eric commented 2 months ago

This should now be fixed.
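To confirm that a fixed Warp version really calls overload_fn_grad from adj.py, and not an automatically generated adjoint, one can compare the resulting x.grad and y.grad against hand-computed values. The sketch below is plain Python and assumes the loss is the sum of both return values over all elements, so adj_ret0 = adj_ret1 = 1 for every element. The helper names custom_adjoint and analytic_adjoint are hypothetical and only mirror the formulas in the repro.

```python
# Hypothetical helpers: reference gradients for the repro inputs in adj.py.
# Assumption: loss = sum_i (ret0_i + ret1_i), so adj_ret0 = adj_ret1 = 1.

def custom_adjoint(x, y, adj_ret0=1.0, adj_ret1=1.0):
    """Mirrors overload_fn_grad: returns (adj_x, adj_y) for one element."""
    adj_x = x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    adj_y = y * adj_ret1 * 3.0
    return adj_x, adj_y

def analytic_adjoint(x, y, adj_ret0=1.0, adj_ret1=1.0):
    """True derivatives of overload_fn: ret0 = 3x + y/3, ret1 = y**2.5."""
    adj_x = 3.0 * adj_ret0
    adj_y = adj_ret0 / 3.0 + 2.5 * y**1.5 * adj_ret1
    return adj_x, adj_y

xs = [1.0, 2.0, 3.0]
ys = [4.0, 5.0, 6.0]

custom = [custom_adjoint(x, y) for x, y in zip(xs, ys)]
analytic = [analytic_adjoint(x, y) for x, y in zip(xs, ys)]

for (cx, cy), (ax, ay) in zip(custom, analytic):
    print(f"custom: ({cx:.3f}, {cy:.3f})  analytic: ({ax:.3f}, {ay:.3f})")
```

Because the custom adjoint in the repro is deliberately not the true derivative, the two sets of values differ clearly. If x.grad comes out as [82, 134, 186] and y.grad as [12, 15, 18], the user-defined adjoint was used. In Warp, one way to obtain these gradients is to accumulate both return values into a one-element loss array with wp.atomic_add inside a wp.Tape and then call tape.backward(loss), with x and y created using requires_grad=True.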