ros2 / rclpy

rclpy (ROS Client Library for Python)
Apache License 2.0

Bad interaction between `torch.compile` and `MultiThreadedExecutor` #1288

Closed: alberthli closed this issue 3 weeks ago

alberthli commented 1 month ago

Issue

Required Info

Description

I have an application with a medium-sized torch model that I'd like to run as fast as possible with my hardware stack in the loop. To do so, I compile the model and call it from a timer callback. The problem is that when I use a MultiThreadedExecutor and have other callbacks in callback groups different from the one that runs the compiled torch model, torch throws an error. I'm posting the issue here because ROS 2 users are much more likely to be familiar with torch than the other way around, and may have some insight into workarounds.

EDIT: I can avoid these errors by keeping everything on the CPU, but some insight would still be nice! In my real application, the CPU version runs about 2x slower than moving the model and data to the GPU.
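For context, a MultiThreadedExecutor runs ready callbacks on a pool of worker threads. Callbacks in the same MutuallyExclusiveCallbackGroup never run in parallel, while callbacks in different groups may overlap on different threads. The stdlib-only sketch below models each callback group as a plain `threading.Lock` and the executor as a `concurrent.futures.ThreadPoolExecutor`. This is an analogy for illustration, not how rclpy implements callback groups, and `run_timers` is a hypothetical helper.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def run_timers(shared_group: bool) -> int:
    """Run two 'timer callbacks' once each and return the peak number of
    callbacks that were executing at the same time.

    Each callback group is modelled as a lock (an analogy only): a callback
    must hold its group's lock while it runs.
    """
    lock_a = threading.Lock()
    lock_b = lock_a if shared_group else threading.Lock()
    state_lock = threading.Lock()
    active = 0
    peak = 0
    # Only used when the groups differ: both callbacks wait here, which shows
    # they are executing at the same time on different worker threads.
    barrier = threading.Barrier(2, timeout=5)

    def callback(group_lock: threading.Lock) -> None:
        nonlocal active, peak
        with group_lock:
            with state_lock:
                active += 1
                peak = max(peak, active)
            if not shared_group:
                barrier.wait()
            with state_lock:
                active -= 1

    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = [pool.submit(callback, lock_a), pool.submit(callback, lock_b)]
        for f in futures:
            f.result()
    return peak


print(run_timers(shared_group=True))   # 1: same group, never overlapping
print(run_timers(shared_group=False))  # 2: different groups, overlapping
```

In the MWE terms, case 2 puts both timers in one group, so the two callbacks never overlap; case 1 uses two groups, so `callback1` and `callback2` can run at the same time on different executor threads.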

Here's an MWE (minimal working example) that reproduces some of the undesirable behavior.

import rclpy
import torch
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node

torch.set_float32_matmul_precision('high')

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return self.lin(x)

class MyNode(Node):
    def __init__(self):
        super().__init__("cube_estimator")
        model = MyModule().to("cuda")

        # case 1: fails! this is what I want in my application
        self.model = torch.compile(model, mode="reduce-overhead")
        cbg1 = MutuallyExclusiveCallbackGroup()
        cbg2 = MutuallyExclusiveCallbackGroup()
        self.create_timer(0.001, self.callback1, callback_group=cbg1)
        self.create_timer(0.001, self.callback2, callback_group=cbg2)

        # case 2: succeeds
        # self.model = torch.compile(model, mode="reduce-overhead")
        # cbg3 = MutuallyExclusiveCallbackGroup()
        # self.create_timer(0.001, self.callback1, callback_group=cbg3)
        # self.create_timer(0.001, self.callback2, callback_group=cbg3)

        # case 3: succeeds!
        # self.model = model
        # cbg4 = MutuallyExclusiveCallbackGroup()
        # cbg5 = MutuallyExclusiveCallbackGroup()
        # self.create_timer(0.001, self.callback1, callback_group=cbg4)
        # self.create_timer(0.001, self.callback2, callback_group=cbg5)

    def callback1(self):
        self.model(torch.ones(100, device="cuda"))

    def callback2(self):
        self.get_logger().info("dummy")

def main(args=None):
    rclpy.init(args=args)
    my_node = MyNode()
    executor = MultiThreadedExecutor()
    executor.add_node(my_node)
    executor.spin()
    my_node.destroy_node()
    rclpy.shutdown()

if __name__ == '__main__':
    main()
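One possible workaround, not confirmed in this thread and offered only as a sketch: if the failure comes from the compiled model being called from different executor threads, keep the two-group setup but route every model call through one dedicated worker thread, so all CUDA work for the model happens on a single thread. `ModelThread` is a hypothetical helper; in the node you might create it in `__init__` and have `callback1` call `self.runner.call(lambda: self.model(torch.ones(100, device="cuda")))`. The stand-in `fake_model` replaces the compiled module so the sketch runs without torch or ROS.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class ModelThread:
    """Runs every submitted call on one dedicated worker thread (hypothetical helper)."""

    def __init__(self) -> None:
        # max_workers=1 means the pool owns exactly one thread, so any
        # thread-local state created by the first call is reused by later calls.
        self._pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="model")

    def call(self, fn, *args, **kwargs):
        # Block the calling (executor) thread until the model thread is done.
        return self._pool.submit(fn, *args, **kwargs).result()

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)


def fake_model(x):
    # Stand-in for the compiled torch model: also report which thread ran it.
    return x * 2, threading.get_ident()


runner = ModelThread()
results = []


def callback(x):
    # Simulates callback1 being invoked from arbitrary executor threads.
    results.append(runner.call(fake_model, x))


threads = [threading.Thread(target=callback, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
runner.shutdown()

print(sorted(v for v, _ in results))     # [0, 2, 4, 6]
print(len({tid for _, tid in results}))  # 1: every call ran on the same thread
```

Because `call` blocks, model calls are still serialized, but `callback2` keeps running in its own callback group instead of waiting behind the model.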

Running the node above produces the error below. Any insight into why this happens (on either the torch or the ROS side) would be helpful!

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1102, in _record
    static_outputs = model(inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_root/cs/ccsrimifvcnflpxc2ziq5ubgcfo2brbcradisby3uiocjrt3li3y.py", line 41, in call
    extern_kernels.addmm(primals_2, reinterpret_tensor(primals_3, (1, 100), (100, 1), 0), reinterpret_tensor(primals_1, (100, 10), (1, 100), 0), alpha=1, beta=1, out=buf0)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/dev_ws/install/cube_estimation/lib/cube_estimation/test_node", line 33, in <module>
    sys.exit(load_entry_point('cube-estimation==0.0.1', 'console_scripts', 'test_node')())
  File "/home/dev_ws/build/cube_estimation/cube_estimation/test_node.py", line 55, in main
    executor.spin()
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 294, in spin
    self.spin_once()
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 794, in spin_once
[INFO] [1716423868.286648916] [cube_estimator]: dummy
    self._spin_once_impl(timeout_sec)
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 791, in _spin_once_impl
    future.result()
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/task.py", line 94, in result
    raise self.exception()
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/task.py", line 239, in __call__
    self._handler.send(None)
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 437, in handler
    await call_coroutine(entity, arg)
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 351, in _execute_timer
    await await_or_execute(tmr.callback)
  File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 107, in await_or_execute
    return callback(*args)
  File "/home/dev_ws/build/cube_estimation/cube_estimation/test_node.py", line 45, in callback1
    self.model(torch.ones(100, device="cuda"))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dev_ws/build/cube_estimation/cube_estimation/test_node.py", line 15, in forward
    def forward(self, x):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 88, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 505, in forward
    fw_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", line 906, in __call__
    return self.get_current_callable()(inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 838, in run
    return compiled_fn(new_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 370, in deferred_cudagraphify
    return fn(inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 784, in run
    return model(new_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1757, in run
    out = self._run(new_inputs, function_id)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1831, in _run
    return self.record_function(new_inputs, function_id)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1862, in record_function
    node = CUDAGraphNode(
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 902, in __init__
    ] = self._record(wrapped_function.model, recording_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1094, in _record
    with preserve_rng_state(), torch.cuda.device(
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/graphs.py", line 184, in __exit__
    self.cuda_graph.capture_end()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/graphs.py", line 82, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The following exception was never retrieved: beginAllocateToPool: already recording to mempool_id
audrow commented 3 weeks ago

It's not clear to me that this is a bug in ROS 2.

I think a better place for this is the Robotics Stack Exchange.