NVIDIA / warp

A Python framework for high performance GPU simulation and graphics
https://nvidia.github.io/warp/

[QUESTION] How to use Autodifferentiation for soft-body parameters #346

Open rrzhang139 opened 2 weeks ago

rrzhang139 commented 2 weeks ago

Hello, first of all, thank you for building this.

I have a question about differentiating parameters in Warp. If I wanted to differentiate soft-body parameters like k_mu or k_lambda in soft_grid, how would that generally work? Would I initialize some tensors and pass them into kernels to get that working? Do you have any code samples that do this for soft-body simulation? Thanks!

amabilee commented 2 weeks ago

Hey there!

To differentiate soft body parameters like k_mu or k_lambda in Warp, you indeed need to initialize tensors and pass them into kernels. Warp supports automatic differentiation, which allows you to compute gradients of these parameters.

Here is the doc about it: https://nvidia.github.io/warp/modules/differentiability.html

shi-eric commented 2 weeks ago

I'm not an expert on the warp.sim things, but I'll attempt an answer. You would generally need to do two things:

rrzhang139 commented 2 weeks ago

Thank you!

rrzhang139 commented 2 weeks ago

Sorry, I tried implementing your method of computing gradients on the Model.tet_materials array, but the gradients do not seem to be updating. Here is my optimization step:

        tape = wp.Tape()

        with tape:
            print(f"\nStart of simulation:")
            print(f"target k_mu: {self.target_model.tet_materials.numpy()[0][0]}")
            print(f"k_mu: {self.model.tet_materials.numpy()[0][0]}")

            for _ in range(self.num_frames):
                self.step()

            print(f"End of simulation:")
            print(f"k_mu: {self.model.tet_materials.numpy()[0][0]}")
            print(f"k_mu gradients: {self.model.tet_materials.grad.numpy()[0][0]}")
            loss = self.compute_loss()

        tape.backward(loss)
        tape.zero()

The gradients are 0.0, and I compute loss by calculating the MSE between self.state_0.particle_q and self.target_state_0.particle_q.

Here is my entire script in case you'd like to reproduce it. Thank you so much for your help!

import math
import os
import numpy as np
import torch
import warp as wp
import warp.sim
# from pxr import Usd, UsdGeom

from renderutils import SoftRenderer
from utils.logging import write_imglist_to_dir, write_imglist_to_gif

@wp.kernel
def compute_mse_loss(
    current_pos: wp.array(dtype=wp.vec3),
    target_pos: wp.array(dtype=wp.vec3), 
    loss: wp.array(dtype=float),
):
    # Get thread index
    idx = wp.tid()

    # Compute difference for this position
    diff = current_pos[idx] - target_pos[idx]
    squared_dist = wp.dot(diff, diff)

    # Atomically add to total loss
    wp.atomic_add(loss, 0, squared_dist)

@wp.kernel
def compute_image_mse_loss(
    rendered_frames: wp.array(dtype=wp.float32),
    target_frames: wp.array(dtype=wp.float32),
    loss: wp.array(dtype=float),
):
    # Get thread index
    idx = wp.tid()

    # Compute pixel-wise difference
    diff = rendered_frames[idx] - target_frames[idx]
    squared_diff = diff * diff

    # Atomic add to accumulate loss across all pixels
    wp.atomic_add(loss, 0, squared_diff)

class WarpFEMDemo:
    def __init__(self, 
                 num_frames=100,
                 device="cuda:0",
                 sim_duration=2.0,
                 physics_engine_rate=60,
                 sim_substeps=32,
                 learning_rate=0.01,
                 num_iterations=100,
                 optimize_param="k_mu",  # New parameter to control what we optimize
                 privileged_mode=True):

        # Simulation parameters
        self.frame_dt = 1.0 / physics_engine_rate
        self.sim_substeps = sim_substeps
        self.sim_dt = self.frame_dt / sim_substeps
        self.sim_time = 0.0
        self.num_frames = num_frames
        self.privileged_mode = privileged_mode

        # Initialize renderer
        self.renderer = SoftRenderer(camera_mode="look_at", device=device)
        self.setup_camera()

        # Add optimization parameters
        self.learning_rate = learning_rate
        self.num_iterations = num_iterations
        self.device = device

        # Replace SimpleModel with direct Warp array
        self.initial_pos = wp.vec3(0.0, 0.0, 0.0)
        self.initial_velocity = wp.vec3(-2.0, 0.0, 0.0)

        # Initialize material parameters
        self.k_mu = 1000.0
        self.k_lambda = 1000.0
        self.k_damp = 1.0
        self.density = 10.0

        self.optimize_param = optimize_param

        # Create target cube state
        # self.target = wp.vec3(-1.0, 0.0, 0.0)
        self.target = 3000.0
        self.build_target_state()

        # Build FEM model
        self.build_fem_model()

        # Initialize states for both models
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()
        self.target_state_0 = self.target_model.state()
        self.target_state_1 = self.target_model.state()

        # Setup integrator
        self.integrator = wp.sim.SemiImplicitIntegrator()

        # Store frames for both actual and target simulations
        self.rendered_frames = []
        self.target_frames = []

        # CUDA optimization
        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self.simulate_step()
            self.graph = capture.graph
        self.loss = wp.zeros(1, dtype=float, device=self.device, requires_grad=True)

    def build_target_state(self):
        # Build a second cube as target
        target_builder = wp.sim.ModelBuilder()

        # Use same parameters as original cube
        cell_dim = [20, 4, 10]
        cell_size = [0.01, 0.005, 0.01]

        # Target cube starts at the same position as the original; only k_mu differs
        if self.optimize_param == "k_mu":
            target_builder.add_soft_grid(
                pos=self.initial_pos,
                rot=wp.quat_identity(),
                vel=wp.vec3(-2.0, 0.0, 0.0),
                dim_x=cell_dim[0],
                dim_y=cell_dim[1],
                dim_z=cell_dim[2],
                cell_x=cell_size[0],
                cell_y=cell_size[1],
                cell_z=cell_size[2],
                density=self.density,
                k_mu=self.target,
                k_lambda=self.k_lambda,
                k_damp=self.k_damp
            )

        self.target_model = target_builder.finalize(requires_grad=True)
        self.target_state = self.target_model.state()

    def setup_camera(self):
        camera_distance = 8.0
        elevation = 30.0
        azimuth = 0.0
        self.renderer.set_eye_from_angles(camera_distance, elevation, azimuth)

    def build_fem_model(self):
        builder = wp.sim.ModelBuilder()

        # FEM parameters
        cell_dim = [20, 4, 10]
        cell_size = [0.01, 0.005, 0.01]

        if self.optimize_param == "k_mu":
            builder.add_soft_grid(
                pos=self.initial_pos,
                rot=wp.quat_identity(),
                vel=self.initial_velocity,
                dim_x=cell_dim[0],
                dim_y=cell_dim[1],
                dim_z=cell_dim[2],
                cell_x=cell_size[0],
                cell_y=cell_size[1],
                cell_z=cell_size[2],
                density=self.density,
                k_mu=self.k_mu,
                k_lambda=self.k_lambda,
                k_damp=self.k_damp
            )

        self.model = builder.finalize(requires_grad=True)
        self.model.ground = True
        self.control = self.model.control()

    def simulate_step(self):
        wp.sim.collide(self.model, self.state_0)
        wp.sim.collide(self.target_model, self.target_state_0)

        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            self.integrator.simulate(
                self.model, 
                self.state_0, 
                self.state_1, 
                self.sim_dt, 
                self.control
            )
            self.state_0, self.state_1 = self.state_1, self.state_0

            # Target model simulation
            self.target_state_0.clear_forces()
            self.integrator.simulate(
                self.target_model,
                self.target_state_0,
                self.target_state_1,
                self.sim_dt,
                self.target_model.control()
            )
            self.target_state_0, self.target_state_1 = self.target_state_1, self.target_state_0

    def step(self):
        if self.use_cuda_graph:
            wp.capture_launch(self.graph)
        else:
            self.simulate_step()
        self.sim_time += self.frame_dt

    def render(self):
        # Convert Warp state to torch tensors for renderer
        vertices = torch.from_numpy(self.state_0.particle_q.numpy()).float()
        faces = torch.from_numpy(self.model.tri_indices.numpy()).long()

        # Create simple textures
        textures = torch.ones(1, faces.shape[-2], 2, 3, device=self.renderer.device)
        textures[..., 2] = 0  # Make it yellow-ish

        rgba = self.renderer.forward(
            vertices.unsqueeze(0).to(self.renderer.device),
            faces.unsqueeze(0).to(self.renderer.device),
            textures
        )
        return rgba

    def simulate_target(self):
        # Reset target state
        target_state_0 = self.target_model.state()
        target_state_1 = self.target_model.state()

        frames = []
        for _ in range(self.num_frames):
            # Simulate target model
            wp.sim.collide(self.target_model, target_state_0)

            for _ in range(self.sim_substeps):
                target_state_0.clear_forces()
                self.integrator.simulate(
                    self.target_model,
                    target_state_0,
                    target_state_1,
                    self.sim_dt,
                    self.target_model.control()
                )
                target_state_0, target_state_1 = target_state_1, target_state_0

            # Render target frame
            vertices = torch.from_numpy(target_state_0.particle_q.numpy()).float()
            faces = torch.from_numpy(self.target_model.tri_indices.numpy()).long()

            textures = torch.ones(1, faces.shape[-2], 2, 3, device=self.renderer.device)
            textures[..., 0] = 0  # Make it cyan-ish

            rgba = self.renderer.forward(
                vertices.unsqueeze(0).to(self.renderer.device),
                faces.unsqueeze(0).to(self.renderer.device),
                textures
            )
            frames.append(rgba)

        return frames

    def compute_loss(self):
        if not self.privileged_mode:
            # Stack frames into tensors
            rendered = torch.stack(self.rendered_frames)
            target = torch.stack(self.target_frames)

            # Convert to numpy arrays for warp
            rendered_np = rendered.cpu().numpy()
            target_np = target.cpu().numpy()

            wp.launch(
                kernel=compute_image_mse_loss,
                dim=rendered_np.size,
                inputs=[
                    wp.array(rendered_np.flatten(), dtype=wp.float32),
                    wp.array(target_np.flatten(), dtype=wp.float32),
                    self.loss
                ],
                device=self.device
            )
        else:
            # Use existing privileged mode loss
            wp.launch(
                kernel=compute_mse_loss,
                dim=len(self.state_0.particle_q),
                inputs=[self.state_0.particle_q, self.target_state_0.particle_q, self.loss],
                device=self.device
            )
        return self.loss

    def optimize_step(self):
        tape = wp.Tape()

        with tape:
            print(f"\nStart of simulation:")
            print(f"target k_mu: {self.target_model.tet_materials.numpy()[0][0]}")
            print(f"k_mu: {self.model.tet_materials.numpy()[0][0]}")

            for _ in range(self.num_frames):
                self.step()

            print(f"End of simulation:")
            print(f"k_mu: {self.model.tet_materials.numpy()[0][0]}")
            print(f"k_mu gradients: {self.model.tet_materials.grad.numpy()[0][0]}")
            loss = self.compute_loss()

        tape.backward(loss)
        tape.zero()
        return loss

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--num_frames", type=int, default=300)
    parser.add_argument("--sim_duration", type=float, default=2.0)
    parser.add_argument("--physics_rate", type=int, default=60)
    parser.add_argument("--sim_substeps", type=int, default=32)
    parser.add_argument("--logdir", type=str, default="warplogs/warp-demo-fem")
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--num_iterations", type=int, default=100)
    parser.add_argument("--verbose", type=bool, default=True)
    args = parser.parse_args()

    # Create demo instance
    demo = WarpFEMDemo(
        device=args.device,
        num_frames=args.num_frames,
        sim_duration=args.sim_duration,
        physics_engine_rate=args.physics_rate,
        sim_substeps=args.sim_substeps,
        learning_rate=args.learning_rate,
        num_iterations=args.num_iterations,
        # verbose=args.verbose
    )

    # Generate target frames once before optimization
    target_frames = demo.simulate_target()

    # Optimization loop
    for iteration in range(args.num_iterations):
        # Reset simulation state
        demo.state_0 = demo.model.state()
        demo.state_1 = demo.model.state()
        # Run simulation
        frames = []
        for _ in range(args.num_frames):
            demo.step()
            rgba = demo.render()
            frames.append(rgba)

        # Compute loss and optimize
        loss = demo.optimize_step()

        if args.verbose:
            print(f"Iteration {iteration}, Loss: {loss.numpy()[0]:.6f}")

        # Save intermediate results
        if iteration % 10 == 0:
            save_dir = os.path.join(args.logdir, f"iteration_{iteration}")
            os.makedirs(save_dir, exist_ok=True)
            write_imglist_to_gif(frames, os.path.join(save_dir, "sim.gif"), imgformat="rgba")

    # Save final output
    print(f"Saving to {args.logdir}")
    write_imglist_to_dir(frames, os.path.join(args.logdir, "frames"), imgformat="rgba")
    write_imglist_to_gif(frames, os.path.join(args.logdir, "final_sim.gif"), imgformat="rgba")

if __name__ == "__main__":
    import argparse
    main()

shi-eric commented 2 weeks ago

Hey @rrzhang139, I think your script references some local code, so I can't run it on my system. One issue I see is that you're only using two model states and swapping them inside simulate_step(). That won't work when computing gradients, because automatic differentiation requires all intermediate states to be available during the backward pass. See https://github.com/NVIDIA/warp/blob/3f9038d7234b036d7c690314b5a5d7e7ed449e75/warp/examples/optim/example_cloth_throw.py#L104-L107 for an example of creating enough model states so that gradients can be correctly propagated.

Also, in your first snippet, the .grad arrays won't be populated because your statements are inside the Tape() context. The Tape.backward() is what performs the backward pass, so you would inspect gradients after calling backward() and before calling Tape.zero().
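To illustrate why every intermediate state has to stay available, here is a minimal plain-Python sketch. It does not use Warp. It integrates a hypothetical unit-mass 1D spring with semi-implicit Euler and hand-writes the reverse pass for the stiffness k. The names simulate and loss_and_grad, the toy spring, and all numeric values are assumptions made for illustration only. Note that the backward loop reads the stored position from every step, which is exactly what gets lost when two states are swapped and overwritten.

```python
def simulate(k, x0=1.0, v0=0.0, dt=0.01, steps=200):
    """Semi-implicit Euler for a unit-mass spring; keeps every state."""
    states = [(x0, v0)]
    x, v = x0, v0
    for _ in range(steps):
        v = v - dt * k * x   # velocity update uses the old position
        x = x + dt * v       # position update uses the new velocity
        states.append((x, v))
    return states


def loss_and_grad(k, x_target, dt=0.01, steps=200):
    """Return the squared error at the final step and its gradient w.r.t. k."""
    states = simulate(k, dt=dt, steps=steps)
    x_final = states[-1][0]
    loss = (x_final - x_target) ** 2

    # Reverse pass: walk the steps backwards, accumulating adjoints.
    adj_x = 2.0 * (x_final - x_target)  # dL/dx at the final step
    adj_v = 0.0                         # dL/dv at the final step
    adj_k = 0.0                         # accumulated dL/dk
    for n in range(steps - 1, -1, -1):
        x_n = states[n][0]              # needs the stored state of step n
        adj_v += dt * adj_x             # from x_{n+1} = x_n + dt * v_{n+1}
        adj_k += -dt * x_n * adj_v      # from v_{n+1} = v_n - dt * k * x_n
        adj_x += -dt * k * adj_v        # same line, dependence on x_n
    return loss, adj_k


# Target trajectory end point generated with a different (assumed) stiffness.
x_target = simulate(6.0)[-1][0]
loss, grad = loss_and_grad(4.0, x_target)
print("loss:", loss, "dL/dk:", grad)
```

If simulate kept only the two most recent states, the x_n values needed in the loop would already be overwritten, and the accumulated gradient would be wrong. A tape-based system faces the same constraint, which is why the linked example allocates one state per step.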

@eric-heiden for viz

rrzhang139 commented 1 week ago

Thank you @shi-eric, my gradients were flip-flopping, which pointed to the swapped states. The problem I have now is instability and non-convergence. Does Warp have an example script for material parameter prediction? It seems inherently difficult because you have to track a target model's trajectory, which introduces a moving optimization target. Also, the material parameter may not carry enough information to update the gradients smoothly.

Let me know what your thoughts are. I have pasted a gist of my updated code below. It should no longer have any local dependencies, so you can run the script as-is.

https://gist.github.com/rrzhang139/707d9920be003faf142c32e2ac8e892b

Thanks for your help!

shi-eric commented 1 week ago

Hey @rrzhang139, I don't have experience with that, but I'll ask around to see if anyone else from the team can weigh in. One of our favorite examples from the community is NCLaw: https://github.com/PingchuanMa/NCLaw Maybe they employ some techniques to get around the issues you are running into.

rrzhang139 commented 1 week ago

Thank you, @shi-eric! One quick question that may help me: one possible improvement is to scale the gradients in log space. Is there a way to run the forward pass of the simulation in regular space and then convert to log space for the gradient calculation? Would I need to modify the source code?
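One possible approach that should not require touching the simulator is a reparameterization. Optimize theta = log(k), run the forward pass with k = exp(theta), and convert the regular-space gradient with the chain rule, dL/dtheta = dL/dk * k. The sketch below is plain Python with a toy loss (the static displacement F / k of a spring under a made-up load F), not Warp code. The names loss_and_grad_k and optimize_log_space, the load, the learning rate, and the iteration count are all assumptions for illustration.

```python
import math

F = 1000.0  # hypothetical load, chosen so displacements are of order 1


def loss_and_grad_k(k, k_target):
    """Toy loss in regular space: squared error of static displacement F / k."""
    x = F / k
    x_target = F / k_target
    loss = (x - x_target) ** 2
    dloss_dk = 2.0 * (x - x_target) * (-F / k**2)
    return loss, dloss_dk


def optimize_log_space(k_init, k_target, lr=0.5, iterations=200):
    """Gradient descent on theta = log(k); the forward pass still uses k."""
    theta = math.log(k_init)
    for _ in range(iterations):
        k = math.exp(theta)                    # back to regular space
        _, dloss_dk = loss_and_grad_k(k, k_target)
        dloss_dtheta = dloss_dk * k            # chain rule: dk/dtheta = k
        theta -= lr * dloss_dtheta
    return math.exp(theta)


k_est = optimize_log_space(k_init=1000.0, k_target=3000.0)
print("estimated k:", k_est)
```

In a Warp setting, the analogous step would presumably be to read the gradient of the material parameter after the backward pass and multiply it by the current parameter value before updating theta. That part is an untested assumption about how it would fit into the script above.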