NVIDIA / warp

A Python framework for high performance GPU simulation and graphics
https://nvidia.github.io/warp/

Question about wp.launch #194

Closed wangpinzhi closed 3 weeks ago

wangpinzhi commented 2 months ago

Thank you for your wonderful work. I have a question about the wp.launch function. When there are non-wp.array variables in the inputs/outputs, what should the corresponding adj_inputs and adj_outputs be? I have an example; can you tell me if the code is correct?

import warp as wp
import numpy as np
import torch
from warp_utils import *

wp.init()

grid_size = (3, 3, 3)
@wp.kernel
def grid_normalization_and_gravity(
    grid_m: wp.array(dtype=float, ndim=3), 
    grid_v_in: wp.array(dtype=wp.vec3, ndim=3), # type: ignore
    gravitational_accelaration: wp.vec3,
    dt: float,
    grid_v_out: wp.array(dtype=wp.vec3, ndim=3), # type: ignore
):
    grid_x, grid_y, grid_z = wp.tid()
    v_out = grid_v_in[grid_x, grid_y, grid_z] * (1.0 / grid_m[grid_x, grid_y, grid_z])
    # add gravity
    v_out = v_out + dt * gravitational_accelaration
    grid_v_out[grid_x, grid_y, grid_z] = v_out

class Rosenbrock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grid_v_in, grid_m, gravitational_accelaration, dt):
        # ensure Torch operations complete before running Warp
        wp.synchronize_device()
        ctx.grid_v_in = wp.from_torch(grid_v_in, dtype=wp.vec3)
        ctx.grid_m = wp.from_torch(grid_m)

        ctx.gravitational_accelaration = gravitational_accelaration
        ctx.dt = dt

        # allocate output
        ctx.grid_v_out = wp.zeros(shape=grid_size, dtype=wp.vec3, requires_grad=True)

        wp.launch(
            kernel=grid_normalization_and_gravity,
            dim=grid_size,
            inputs=[ctx.grid_m, ctx.grid_v_in, ctx.gravitational_accelaration, ctx.dt],
            outputs=[ctx.grid_v_out]
        )

        # ensure Warp operations complete before returning data to Torch
        wp.synchronize_device()

        return wp.to_torch(ctx.grid_v_out)

    @staticmethod
    def backward(ctx, grid_v_out_grad):
        # ensure Torch operations complete before running Warp
        wp.synchronize_device()

        # map incoming Torch grads to our output variables
        ctx.grid_v_out.grad = wp.from_torch(grid_v_out_grad.contiguous(), dtype=wp.vec3)
        ctx.adj_grav = wp.vec3(0.0, 0.0, 0.0)
        print('before', ctx.grid_v_in.grad)

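        # launch the kernel in adjoint mode so gradients are accumulated into the .grad arrays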
        wp.launch(
            kernel=grid_normalization_and_gravity,
            dim=grid_size,
            inputs=[ctx.grid_m, ctx.grid_v_in, ctx.gravitational_accelaration, ctx.dt],
            outputs=[ctx.grid_v_out],
            adj_inputs=[None, ctx.grid_v_in.grad, ctx.adj_grav, wp.float32(0.0)],
            adj_outputs=[ctx.grid_v_out.grad],
            adjoint=True
        )

        # ensure Warp operations complete before returning data to Torch
        wp.synchronize_device()
        print('after', ctx.grid_v_in.grad)
        # return adjoint w.r.t. inputs
        return (wp.to_torch(ctx.grid_v_in.grad), None, None, None)

num_points = 1500
learning_rate = 5e-2

torch_device = wp.device_to_torch(wp.get_device())

grid_v_in = torch.zeros((3, 3, 3, 3), dtype=torch.float32, requires_grad=True, device=torch_device)

grid_m = torch.randn((3, 3, 3), dtype=torch.float32, device=torch_device)
gravitational_accelaration = wp.vec3(1.0, 1.0, 1.0)
dt = 0.05

opt = torch.optim.Adam([grid_v_in], lr=learning_rate)

for _ in range(100):
    # step
    opt.zero_grad()
    grid_v_out = Rosenbrock.apply(grid_v_in, grid_m, gravitational_accelaration, dt)
    # grid_v_out.grad = torch.randn_like(grid_v_in).contiguous()
    grid_v_out.sum().backward()

    opt.step()
daedalus5 commented 2 months ago

I think it looks okay. You can pass None to an adj_input / adj_output when it hasn't been allocated or should not participate in backpropagation. You can also pass None if the gradient exists and is already associated with its corresponding Warp array. In this case, you allocate a gradient array for grid_v_in, and it gets mapped to the corresponding Warp array here: ctx.grid_v_in = wp.from_torch(grid_v_in, dtype=wp.vec3)
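For reference, here is a minimal standalone sketch of the same pattern (not from the thread; the scale kernel and its names are made up for illustration). The array gradients are already associated with their Warp arrays via requires_grad=True, so None is passed in their adj_inputs / adj_outputs slots, while the non-array float parameter receives a zero-valued adjoint seed, following the pattern in the example above.

import numpy as np
import warp as wp

wp.init()

@wp.kernel
def scale(x: wp.array(dtype=float), s: float, y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = s * x[i]

n = 8
# arrays created with requires_grad=True already own a .grad array,
# so None can be passed for their adj_inputs / adj_outputs slots
x = wp.array(np.ones(n, dtype=np.float32), dtype=float, requires_grad=True)
y = wp.zeros(n, dtype=float, requires_grad=True)

# forward launch
wp.launch(scale, dim=n, inputs=[x, 2.0], outputs=[y])

# seed the output gradient (dL/dy = 1)
y.grad.fill_(1.0)

# adjoint launch:
#   None for x and y    -> their gradients are already attached to the arrays
#   0.0 for the float s -> non-array parameters take a zero-valued adjoint seed
wp.launch(
    scale,
    dim=n,
    inputs=[x, 2.0],
    outputs=[y],
    adj_inputs=[None, wp.float32(0.0)],
    adj_outputs=[None],
    adjoint=True,
)

print(x.grad.numpy())  # expected: all 2.0, i.e. s * dL/dy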