facebookincubator / AITemplate

AITemplate is a Python framework which renders neural network into high performance CUDA/HIP C++ code. Specialized for FP16 TensorCore (NVIDIA GPU) and MatrixCore (AMD GPU) inference.
Apache License 2.0

[bug] Error when converting M2M100 transformer model with fx2ait #288

Closed by chi2liu 1 year ago

chi2liu commented 1 year ago

I used fx2ait to convert the M2M100 transformer model, and it throws a PyTorch error: `symbolically traced variables cannot be used as inputs to control flow`. I know this is because control flow is used inside the model. Is there any way for fx2ait to support or bypass this problem?

Original Code:

import unittest

import torch
from fx2ait.example.benchmark_utils import benchmark_function, verify_accuracy

class TestTransformerModule(unittest.TestCase):
    def test_transformer_encoder(self):
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
        model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
        tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        # verify_accuracy(model, inputs)

        results = []
        for batch_size in [1, 4, 16, 32, 64, 128, 256, 512]:
            inputs = [torch.randn(batch_size, 196, 768).half().cuda()]
            results.append(
                benchmark_function(self.__class__.__name__, 100, model, inputs)
            )
        for res in results:
            print(res)

if __name__ == "__main__":
    torch.manual_seed(0)
    unittest.main()

Error:

======================================================================
ERROR: test_transformer_encoder (__main__.TestTransformerModule)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "fx2ait/example/01_transformer_model/test_M2M100.py", line 83, in test_transformer_encoder
    benchmark_function(self.__class__.__name__, 100, model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/fx2ait-0.2.dev1-py3.8-linux-x86_64.egg/fx2ait/example/benchmark_utils.py", line 125, in benchmark_function
    mod = acc_tracer.trace(
  File "/usr/local/lib/python3.8/dist-packages/fx2ait-0.2.dev1-py3.8-linux-x86_64.egg/fx2ait/acc_tracer/acc_tracer.py", line 605, in trace
    traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list)
  File "/usr/local/lib/python3.8/dist-packages/fx2ait-0.2.dev1-py3.8-linux-x86_64.egg/fx2ait/acc_tracer/acc_tracer.py", line 523, in rewriter_base_trace
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
  File "/usr/local/lib/python3.8/dist-packages/fx2ait-0.2.dev1-py3.8-linux-x86_64.egg/fx2ait/acc_tracer/acc_tracer.py", line 319, in trace
    return super().trace(rewritten, concrete_args), rewritten
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/m2m_100/modeling_m2m_100.py", line 1331, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/_symbolic_trace.py", line 717, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/_symbolic_trace.py", line 434, in call_module
    return forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/_symbolic_trace.py", line 710, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/m2m_100/modeling_m2m_100.py", line 1214, in forward
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/proxy.py", line 298, in __bool__
    return self.tracer.to_bool(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/proxy.py", line 174, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

ipiszy commented 1 year ago

Hi @chi2liu, fx tracing doesn't support control flow, so fx2ait doesn't support it either. For now, the simplest approach is to run the control-flow parts in PyTorch eager mode and use fx2ait for the rest. Thanks.
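A minimal sketch of that split, assuming a toy model rather than M2M100: `Core` and `Branching` are hypothetical names made up for illustration, and plain `torch.fx.symbolic_trace` stands in for fx2ait's `acc_tracer`, so this only shows where the boundary between the eager part and the traced part could go, not the actual fx2ait lowering.

```python
import torch
from torch import fx
from torch.fx.proxy import TraceError


class Core(torch.nn.Module):
    """Hypothetical branch-free part: straight-line tensor math only."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Branching(torch.nn.Module):
    """Hypothetical wrapper whose forward branches on one of its inputs."""

    def __init__(self):
        super().__init__()
        self.core = Core()

    def forward(self, x, double_output):
        y = self.core(x)
        # During symbolic tracing `double_output` is a Proxy, so evaluating
        # it in an `if` raises TraceError, as in the traceback above.
        if double_output:
            y = y * 2
        return y


torch.manual_seed(0)
model = Branching().eval()
x = torch.randn(2, 4)
expected = model(x, True)  # pure eager reference result

# Tracing the whole model hits the input-dependent branch and fails.
try:
    fx.symbolic_trace(model)
    whole_model_traceable = True
except TraceError:
    whole_model_traceable = False

# Trace only the branch-free submodule and leave the wrapper in eager mode.
model.core = fx.symbolic_trace(model.core)
result = model(x, True)  # branch runs in eager, core runs the traced graph
```

In this sketch the `if` stays in ordinary Python, and only `model.core` becomes a `GraphModule`; that traced piece is the kind of submodule one would then try to hand to fx2ait instead of the full model.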