pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Torch Dynamo support for Flux T5 model #138196

Open asfiyab-nvidia opened 3 weeks ago

asfiyab-nvidia commented 3 weeks ago

🐛 Describe the bug

I'm trying to export the Flux T5 text encoder to ONNX with torch.onnx.dynamo_export(). The export fails because Dynamo has no support for fused_layer_norm_cuda.PyCapsule.rms_forward_affine, the fused RMSNorm kernel from apex that the T5 layer norm calls.

The script below can be used to reproduce the issue:

import torch
from transformers import T5EncoderModel

# Load Model
model_dir = "black-forest-labs/FLUX.1-dev"
device = "cuda"
model = T5EncoderModel.from_pretrained(model_dir, subfolder="text_encoder_2").to(device)

# Define Input
inputs = torch.zeros(1, 512, dtype=torch.int32, device=device)

# Export to ONNX
out = torch.onnx.dynamo_export(
    model,
    inputs,
)
out.save("t5_dynamo.onnx")

The error is pasted below:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py", line 1474, in dynamo_export
    ).export()
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py", line 1200, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 196, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1430, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 464, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 151, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs), model=model)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1224, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 896, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 697, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 631, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2722, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1654, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1642, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 383, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 322, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 106, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1654, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1642, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 383, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 322, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 106, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1654, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1642, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 383, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 322, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 106, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1564, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1642, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 383, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 322, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 106, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1564, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 322, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 106, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1642, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 954, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 711, in call_method
    return self.call_apply(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 636, in call_apply
    ).call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 1826, in call_function
    (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 528, in speculate_subgraph
    raise ex
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 457, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 322, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 106, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 810, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2937, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3053, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 557, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1564, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 804, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 731, in call_function
    unimplemented(msg)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 283, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Graph break due to unsupported builtin fused_layer_norm_cuda.PyCapsule.rms_forward_affine. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.

from user code:
   File "/usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py", line 1971, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py", line 1106, in forward
    layer_outputs = layer_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py", line 686, in forward
    self_attention_outputs = self.layer[0](
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py", line 592, in forward
    normed_hidden_states = self.layer_norm(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/apex/normalization/fused_layer_norm.py", line 416, in forward
    return fused_rms_norm_affine(
  File "/usr/local/lib/python3.10/dist-packages/apex/normalization/fused_layer_norm.py", line 215, in fused_rms_norm_affine
    return FusedRMSNormAffineFunction.apply(*args)
  File "/usr/local/lib/python3.10/dist-packages/apex/normalization/fused_layer_norm.py", line 75, in forward
    output, invvar = fused_layer_norm_cuda.rms_forward_affine(
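
The last user-code frames show that self.layer_norm resolves to apex's fused RMSNorm, whose forward calls the C++ extension that Dynamo cannot trace. One possible workaround, sketched here and not taken from this thread, is to swap those layers for an RMSNorm written in plain PyTorch ops before exporting. PlainRMSNorm and swap_norm_layers are hypothetical names. The sketch assumes the apex module is a class named FusedRMSNorm that exposes weight and eps attributes, and its results may differ slightly from the fused kernel.

```python
import torch
from torch import nn


class PlainRMSNorm(nn.Module):
    """RMS normalization in plain PyTorch ops (no mean subtraction, no bias)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Compute the mean of squares in float32 for numerical stability.
        as_float = hidden_states.to(torch.float32)
        variance = as_float.pow(2).mean(-1, keepdim=True)
        normed = as_float * torch.rsqrt(variance + self.eps)
        return self.weight * normed.to(self.weight.dtype)


def swap_norm_layers(model: nn.Module, class_name: str = "FusedRMSNorm") -> int:
    """Replace every submodule whose class is named `class_name`.

    Assumption: the matched modules expose `weight` and `eps` attributes.
    Returns the number of replaced modules.
    """
    replaced = 0
    for name, child in list(model.named_children()):
        if type(child).__name__ == class_name:
            new_layer = PlainRMSNorm(child.weight.shape[-1], eps=child.eps)
            # Copy the trained scale; clone() keeps its device and dtype.
            new_layer.weight = nn.Parameter(child.weight.detach().clone())
            setattr(model, name, new_layer)
            replaced += 1
        else:
            replaced += swap_norm_layers(child, class_name)
    return replaced
```

Under these assumptions, swap_norm_layers(model) would be called once after from_pretrained() and before the export call. Whether the rest of the model then exports cleanly is not something this sketch establishes.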

Versions

transformers 4.42.2
diffusers 0.31.0.dev0
torch 2.5.0a0+b465a5843b.nv24.9 (Nvidia NGC 24.09 PyTorch container)

justinchuby commented 3 weeks ago

Please test with torch.onnx.export(..., dynamo=True, report=True) using the latest torch-nightly. Attach the generated report if there is an error. Thanks!
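
For readers unfamiliar with that entry point, here is a minimal sketch of the suggested call on a toy module rather than the Flux T5 encoder. It assumes a PyTorch build in which torch.onnx.export accepts the dynamo and report keyword arguments, and that the ONNX dependencies the exporter needs (such as onnxscript) are installed. TinyEncoder and export_with_report are hypothetical names used only for illustration.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Toy stand-in for a text encoder: embedding followed by a projection."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(input_ids))


def export_with_report(model: nn.Module, example_args: tuple, path: str):
    """Export through the Dynamo-based path and ask for a report.

    Assumption: this PyTorch build supports `dynamo=True` and `report=True`.
    """
    onnx_program = torch.onnx.export(
        model,
        example_args,
        dynamo=True,   # use the torch.export-based exporter
        report=True,   # request a report of the export process
    )
    onnx_program.save(path)
    return onnx_program
```

A call such as export_with_report(TinyEncoder().eval(), (torch.zeros(1, 16, dtype=torch.int64),), "tiny_dynamo.onnx") would then save the model. If the export fails, the generated report is the file to attach to the issue.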

xadupre commented 2 weeks ago

I was able to successfully export your model with torch.onnx.export(model, (dict(input_ids=inputs),), dynamo=True). You should use the nightly build.