comfyanonymous / ComfyUI

The most powerful and modular diffusion model GUI, API, and backend with a graph/nodes interface.
https://www.comfy.org/
GNU General Public License v3.0

torch modules using inference_mode cause problems when doing acceleration #2194

Open zhaozhixu opened 10 months ago

zhaozhixu commented 10 months ago

Hi! Thank you so much for such an awesome repository!

The torch modules here are executed under inference_mode rather than no_grad, which causes problems with some acceleration workflows, such as torch.jit.trace or ONNX export. https://github.com/comfyanonymous/ComfyUI/blob/8112a0d9fcb80c341afa53798f62acdf618cee2b/execution.py#L329

From the commit log I can see this was changed intentionally, but without explanation:

commit f67c00622f1259598ce1720bbcb483fbe6e5de68
Author: comfyanonymous <comfyanonymous@protonmail.com>
Date:   Wed Mar 22 03:48:26 2023 -0400

    Use inference_mode instead of no_grad.

Could the author please kindly explain the reason? @comfyanonymous

P.S. Below is the exception raised when exporting a jit-traced model to ONNX. Changing the above line to with torch.no_grad(): solves the problem.

  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/facexlib/detection/retinaface.py", line 121, in forward
    out = self.body(inputs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torchvision/models/_utils.py", line 69, in forward
    x = module(x)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/ComfyUI/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
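The failure mode can be reproduced outside ComfyUI in a few lines (a minimal sketch, not taken from the issue): tensors created while inference_mode is active become "inference tensors", and when autograd later tries to save one of them for a backward pass, as the jit/ONNX tracer does while recording the graph, PyTorch raises the RuntimeError shown above.

```python
import torch

w = torch.ones(2, requires_grad=True)

# Tensors created while inference_mode is active become inference tensors.
with torch.inference_mode():
    x = torch.ones(2)

assert x.is_inference()

# With grad mode on, multiplying by a tensor that requires grad forces
# autograd to save x for the backward pass, which inference tensors
# cannot participate in -> RuntimeError.
try:
    y = w * x
except RuntimeError as e:
    print("RuntimeError:", e)
```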
comfyanonymous commented 10 months ago

Inference mode disables more things than no_grad, which is why I use it.

If you need to disable it in your code, you can use:

with torch.inference_mode(False):
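In context, that workaround looks like this (a minimal sketch): re-enable autograd locally with torch.inference_mode(False), and clone() any inference tensors to obtain normal tensors that torch.jit.trace or torch.onnx.export can work with, as the RuntimeError message itself suggests.

```python
import torch

# Somewhere up the stack, ComfyUI's executor has entered inference mode.
with torch.inference_mode():
    x = torch.ones(3)
    assert x.is_inference()

    # Locally switch inference mode off inside the custom node's code.
    with torch.inference_mode(False):
        # clone() an inference tensor outside inference mode to get a
        # normal tensor usable with autograd / tracing / ONNX export.
        y = x.clone()
        assert not y.is_inference()
```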