nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

Compile TinyLlama failed: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool #664

Open. ElEHsiang opened this issue 6 months ago.

ElEHsiang commented 6 months ago

I tried to compile the TinyLlama-1.1B-Chat-v1.0 model to a vmfb, but it failed with a parameter type mismatch in torch.nn.functional.scaled_dot_product_attention(). How can I fix it?

P.S. I based this on commit 4a01c405843fd91badbea2a14fd19e0393aade8f because iree-turbine does not have stateless_llama.py.

Command

python models/turbine_models/custom_models/stateless_llama.py --hf_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0" --compile_to=vmfb --hf_auth_token="your token" --external_weights="safetensors" --quantization="unquantized" --precision="f32"

Error log

I0429 19:16:33.474000 139820168992576 torch/fx/experimental/symbolic_shapes.py:2724] create_symbol s0 = 2 for input0.size()[1] [2, 9223372036854775806] (fx/experimental/proxy_tensor.py:1050 in wrap_fake)
I0429 19:16:33.531000 139820168992576 torch/fx/experimental/symbolic_shapes.py:4035] eval s0 <= 2048 [guard added] (transformers/models/llama/modeling_llama.py:150 in forward)
Traceback (most recent call last):
  File "/scratch/yunh/SHARK-Turbine/models/turbine_models/custom_models/stateless_llama.py", line 488, in <module>
    mod_str, _ = export_transformer_model(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/models/turbine_models/custom_models/stateless_llama.py", line 403, in export_transformer_model
    inst = StateUpdateModule(context=Context(), import_to=import_to)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/compiled_module.py", line 652, in __new__
    do_export(proc_def)
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/compiled_module.py", line 649, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/support/procedural/tracer.py", line 122, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/compiled_module.py", line 630, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/models/turbine_models/custom_models/stateless_llama.py", line 195, in run_initialize
    token, *state = self.initialize(x, constraints=init_const)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/support/procedural/base.py", line 135, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/support/procedural/tracer.py", line 138, in handle_call
    return target.resolve_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/builtins/jittable.py", line 215, in resolve_call
    transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/passes/functorch.py", line 47, in functorch_functionalize
    new_gm = proxy_tensor.make_fx(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1081, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 541, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 793, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 559, in wrapped
    out = f(*tensors)
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/passes/functorch.py", line 65, in wrapped
    out = function(*args_functional)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/core/shark_turbine/aot/builtins/jittable.py", line 210, in flat_wrapped_f
    return self.wrapped_f(*pytorch_args, **pytorch_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/models/turbine_models/custom_models/stateless_llama.py", line 265, in initialize
    result = mod.forward(input_ids)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 771, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 496, in call_module
    return forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 764, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 771, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 496, in call_module
    return forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 764, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 771, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 496, in call_module
    return forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 764, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/yunh/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 728, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool

IanNod commented 6 months ago

Our current implementation of scaled dot product attention doesn't yet support is_causal (support should be coming soon). For now, I would suggest decomposing the attention op. @gpetters-amd, weren't you looking at adding that recently?

You can see an example of how we added this as an option to our SDXL pipeline, enabled with the --decomp_attn flag, here: https://github.com/nod-ai/SHARK-Turbine/blob/f919efe78903727d149c997e326f41b54ea1e147/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py#L117
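To make the suggestion to decompose the attention op more concrete: decomposing it means expressing scaled dot product attention as the primitive operations it is built from, namely a matrix product, scaling by 1/sqrt(head_dim), a causal mask, a softmax, and a second matrix product. The sketch below is a framework-free illustration of that math in plain Python for a single head with no batch dimension. It is not Turbine's decomposition table or the code path behind --decomp_attn, and the function name causal_attention is hypothetical.

```python
import math


def causal_attention(q, k, v):
    """Hypothetical sketch: causal scaled dot product attention, written out
    as the primitive steps a decomposition would produce.

    q, k, v are lists of rows with shape [seq_len][head_dim]
    (single head, no batch dimension).
    """
    seq_len = len(q)
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(seq_len):
        # 1. scores = (q @ k^T) * scale, keeping only positions j <= i.
        #    Skipping j > i is equivalent to adding -inf to those scores
        #    before the softmax (the causal mask).
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) * scale
            for j in range(i + 1)
        ]
        # 2. Numerically stable softmax over the unmasked positions.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # 3. Output row = weights @ v (only the visible rows of v).
        out.append([
            sum(w * v[j][c] for j, w in enumerate(weights))
            for c in range(len(v[0]))
        ])
    return out
```

With the causal mask, the first output row can only attend to the first position, so it equals the first row of v; when all scores are equal, row i becomes the mean of rows 0 through i of v.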

NoumanAmir657 commented 6 days ago

I am getting the same error even after adding the decomposition. I added these ops:

    with decompositions.extend_aot_decompositions(
        from_current=True,
        add_ops=[
            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
            torch.ops.aten._scaled_dot_product_flash_attention.default,
            torch.ops.aten.masked_fill_.Scalar,
            torch.ops.aten.copy,
        ],
    ):