nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

DINO ViT model through Turbine #523

Open harishanand95 opened 7 months ago

harishanand95 commented 7 months ago

I'm trying to get the DINO ViT model from Hugging Face transformers running through Turbine. Here is the code I'm using; it's based on the sd_inference code in the repo.

from transformers.models.vit.modeling_vit import ViTModel
import torch
from shark_turbine.aot import *
from iree.compiler.ir import Context
import iree.runtime as rt
from huggingface_hub import hf_hub_download

class DINOModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(repo_id="facebook/dino-vitb16", filename="config.json")
            )
        )

    def forward(self, img):
        return self.model.forward(pixel_values=img, interpolate_pos_encoding=True)

def dino_export(device="cpu", target_triple="", compile_to="vmfb", max_alloc="4294967296"):
    model = DINOModel()

    class CompiledDINOModel(CompiledModule):
        params = export_parameters(model)

        def main(
            self,
            image=AbstractTensor(1, 3, 512, 512, dtype=torch.float32),
        ):
            return jittable(model.forward)(image)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledDINOModel(context=Context(), import_to=import_to)
    module_str = str(CompiledModule.get_mlir_module(inst))

dino_export()

This fails on the line inst = CompiledDINOModel(context=Context(), import_to=import_to) with an error related to mismatched arguments in the bicubic2d interpolation. Here is the exact place in transformers where it fails: https://github.com/huggingface/transformers/blob/469c13280d77a75be626da4f8e918e9f24e4f80f/src/transformers/models/vit/modeling_vit.py#L105C1-L110C10.

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
        mode="bicubic",
        align_corners=False,
    )

AFAIK, scale_factor is expected to be a tuple of floats, but during Turbine tracing it's a tuple of SymFloat, which fails. Calling float(h0) and float(w0) didn't work either, as those operations also fail with a similar error (unable to convert SymFloat to float?).

Here is the error message:

File "/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/transformers/models/vit/modeling_vit.py", line 137, in forward
    embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/transformers/models/vit/modeling_vit.py", line 105, in interpolate_pos_encoding
    patch_pos_embed = nn.functional.interpolate(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/torch/nn/functional.py", line 4046, in interpolate
    return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: upsample_bicubic2d() received an invalid combination of arguments - got (FakeTensor, NoneType, bool, tuple), but expected one of:
 * (Tensor input, tuple of ints output_size, bool align_corners, tuple of floats scale_factors)
      didn't match because some of the arguments have invalid types: (FakeTensor, NoneType, bool, tuple of (SymFloat, SymFloat))
 * (Tensor input, tuple of ints output_size, bool align_corners, float scales_h, float scales_w, *, Tensor out)

In our case, the input always has shape (1, 3, 512, 512), so height and width are always 512.
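Because the input shape is fixed, the target grid of the position-embedding interpolation can be worked out ahead of time. The sketch below does that arithmetic in plain Python. It assumes the ViT-B/16 values from the DINO config (patch size 16, and 196 pretrained patch positions, i.e. a 14x14 grid) and the small +0.1 offset that this transformers version adds to h0 and w0 before building scale_factor; both are worth checking against the linked source. pos_embed_grid is a hypothetical helper, not a transformers API.

```python
import math


def pos_embed_grid(image_size: int, patch_size: int, num_positions: int, offset: float = 0.1):
    """Return (grid_in, scale_factor, grid_out) for one spatial axis.

    Assumed to mirror ViTEmbeddings.interpolate_pos_encoding:
    h0 = image_size // patch_size + offset, scale = h0 / sqrt(num_positions).
    """
    grid_in = int(math.sqrt(num_positions))       # pretrained grid side, 14 for 196 positions
    h0 = image_size // patch_size + offset        # 512 // 16 + 0.1 = 32.1
    scale_factor = h0 / math.sqrt(num_positions)  # 32.1 / 14, roughly 2.293
    # For positive sizes, interpolate's output size is floor(input_size * scale_factor).
    grid_out = math.floor(grid_in * scale_factor)
    return grid_in, scale_factor, grid_out


grid_in, scale, grid_out = pos_embed_grid(image_size=512, patch_size=16, num_positions=196)
print(grid_in, round(scale, 4), grid_out)  # prints: 14 2.2929 32
```

The 14 -> 32 result matches the [1,768,14,14] -> [1,768,32,32] shapes that show up in the iree-compile diagnostics later in this thread. With static shapes, one possible workaround (untested here) is to pass size=(grid_out, grid_out) to interpolate instead of a scale_factor tuple, so no SymFloat reaches upsample_bicubic2d. Note that with align_corners=False the two forms map source coordinates slightly differently (1/scale_factor versus 14/32), so the interpolated values can differ a little.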

How to fix this error?

Thanks!

stellaraccident commented 7 months ago

These issues are purely related to Torch tracing, so I suggest that we pivot this to work the problem at the Torch level. Either it is somewhat likely already fixed in a more recent Torch version, or we need to fix something in Torch. Either way, we need a pure repro. Given that you are just using static shapes and a pretty vanilla invocation, let's just use Torch APIs.

I'd recommend that you upgrade to a recent Torch (and please always include which torch version you are using) and follow the examples here: https://pytorch.org/docs/stable/export.html#overview

If we get this working, the rest should flow. For reference, they are doing a lot of active work in this area, and I do all of my dev on torch nightlies for now (whenever I have any kind of issue). Saves a lot of looking for issues that have already been fixed.
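Since the request above is to always report the Torch version, a small script can capture all the relevant versions at once. A minimal sketch using only the standard library's importlib.metadata (Python 3.8+); the package list is an assumption based on the packages mentioned in this thread, and collect_versions is a hypothetical helper name. It prints in the same "# package version" comment style used elsewhere in this thread.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names mentioned in this thread; adjust for your environment.
PACKAGES = ["torch", "transformers", "iree-compiler", "iree-runtime", "shark-turbine"]


def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"# {name:<25} {ver or 'not installed'}")
```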

harishanand95 commented 7 months ago

Cool, I have switched to the latest Torch nightly. The Turbine jittable() code I shared before still fails, but torch.export.export() works.

# pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20240308%2Bcpu-cp311-cp311-linux_x86_64.whl
import torch  # 2.3.0.dev20240308+cpu
from torch.export import export
from transformers.models.vit.modeling_vit import ViTModel  # 4.35.0
from huggingface_hub import hf_hub_download

class DINOModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(repo_id="facebook/dino-vitb16", filename="config.json")
            )
        )

    def forward(self, img):
        return self.model.forward(pixel_values=img, interpolate_pos_encoding=True)

example_args = (torch.randn(1, 3, 512, 512),)

exported_program = export(DINOModel(), args=example_args)
with open("output.txt", "w") as f:
    print(exported_program, file=f)

Here is the printed FX output: https://gist.github.com/harishanand95/d91b741e4e1e8e6b359fadb355295777 . The interpolate op is present in the graph.


I ended up following some of the code in the paged llama script to run the FX program through IREE. Now I'm hitting an iree-compile error.

# torch                     2.3.0.dev20240308+cpu
# iree-compiler             20240226.813
# iree-runtime              20240226.813

import torch
from torch.export import export
from transformers.models.vit.modeling_vit import ViTModel  # 4.35.0
from huggingface_hub import hf_hub_download
from shark_turbine.importers.fx_importer import FxImporter
import iree.compiler as ireec

class DINOModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(repo_id="facebook/dino-vitb16", filename="config.json")
            )
        )

    def forward(self, img):
        return self.model.forward(pixel_values=img, interpolate_pos_encoding=True)

# https://github.com/nod-ai/SHARK-Turbine/blob/main/models/turbine_models/custom_models/sd_inference/utils.py#L32
def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
    flags = [
        "--iree-input-type=torch",
        "--mlir-print-debuginfo",
        "--mlir-print-op-on-diagnostic=false",
        "--iree-llvmcpu-target-cpu-features=host",
        "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
        "--iree-stream-resource-index-bits=64",
        "--iree-vm-target-index-bits=64",
        "--iree-flow-inline-constants-max-byte-length=1",
    ]
    if device == "cpu":
        flags.append("--iree-llvmcpu-enable-ukernels=all")
        device = "llvm-cpu"
    elif device == "vulkan":
        flags.extend(
            [
                "--iree-hal-target-backends=vulkan-spirv",
                "--iree-vulkan-target-triple=" + target_triple,
                "--iree-stream-resource-max-allocation-size=" + max_alloc,
            ]
        )
    elif device == "rocm":
        flags.extend(
            [
                "--iree-hal-target-backends=rocm",
                "--iree-rocm-target-chip=" + target_triple,
                "--iree-rocm-link-bc=true",
                "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                "--iree-vm-bytecode-module-strip-source-map=true",
                "--iree-opt-strip-assertions=true",
                "--iree-vm-target-truncate-unsupported-floats",
            ]
        )
    elif device == "cuda":
        flags.extend(
            [
                "--iree-hal-target-backends=cuda",
                "--iree-hal-cuda-llvm-target-arch=" + target_triple,
                "--iree-vm-bytecode-module-strip-source-map=true",
                "--iree-vm-target-truncate-unsupported-floats",
            ]
        )
    else:
        print("incorrect device: ", device)

    flatbuffer_blob = ireec.compile_str(
        module_str,
        target_backends=[device],
        extra_args=flags,
    )
    with open(f"{safe_name}.vmfb", "wb+") as f:
        f.write(flatbuffer_blob)
    print("Saved to", safe_name + ".vmfb")

def dino_export(device="cpu", target_triple="", compile_to="vmfb", max_alloc="4294967296"):
    model = DINOModel()
    example_args = (torch.randn(1, 3, 512, 512),)
    prog = export(model, args=example_args) # WORKS!!

    # https://github.com/nod-ai/SHARK-Turbine/blob/d5f4f9a67b4bffd5517e80e543b713c6b3ce2afe/llm/scripts/validate_paged_llama_model.py#L141
    importer = FxImporter()
    importer.import_program(prog, func_name="dino_model")

    output_file = "dino.mlir"
    print("Saving to:", output_file)
    with open(output_file, "w") as f:
        importer.module_op.print(file=f, binary=False) # needs file=f to get mlir text file

    module_str = str(importer.module_op) # iree.compiler._mlir_libs._mlir.ir.Operation
    compile_to_vmfb(module_str, device, target_triple, max_alloc, "dino_model") # Fails!

dino_export()
# dino_export(device="vulkan", target_triple="ampere-rtx4070-linux", max_alloc="4294967296")

Error:

$ python turbine/dino_torch_turbine_export.py 
/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Traceback (most recent call last):
  File "/home/harish/projects/turbine/turbine-lrm/turbine/dino_torch_turbine_export.py", line 98, in <module>
    dino_export()
  File "/home/harish/projects/turbine/turbine-lrm/turbine/dino_torch_turbine_export.py", line 96, in dino_export
    compile_to_vmfb(module_str, device, target_triple, max_alloc, "dino_model") # Fails!
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/harish/projects/turbine/turbine-lrm/turbine/dino_torch_turbine_export.py", line 72, in compile_to_vmfb
    flatbuffer_blob = ireec.compile_str(
                      ^^^^^^^^^^^^^^^^^^
  File "/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/iree/compiler/tools/core.py", line 299, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/iree/compiler/tools/binaries.py", line 198, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Error code: 1
Diagnostics:
<stdin>:178:11: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %57 = torch.operator "torch.aten._unsafe_index.Tensor"(%23, %56) : (!torch.vtensor<[1,768,14,14],f32>, !torch.list<optional<vtensor>>) -> !torch.vtensor<[1,768,32,32],f32>
          ^

Invoked with:
 iree-compile /home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=/home/harish/venvs/nv_turbine-lrm/lib/python3.11/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-flow-inline-constants-max-byte-length=1 --iree-llvmcpu-enable-ukernels=all

Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers.

The MLIR file is huge (~700 MB). I tried the command below on it to reduce the output size, and it fails with the same error:

$ iree-compile --compile-to=input --mlir-elide-elementsattrs-if-larger=16 dino.mlir
dino.mlir:178:11: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %57 = torch.operator "torch.aten._unsafe_index.Tensor"(%23, %56) : (!torch.vtensor<[1,768,14,14],f32>, !torch.list<optional<vtensor>>) -> !torch.vtensor<[1,768,32,32],f32>
          ^
dino.mlir:178:11: note: see current operation: %288 = "torch.operator"(%250, %287) <{name = "torch.aten._unsafe_index.Tensor"}> : (!torch.vtensor<[1,768,14,14],f32>, !torch.list<optional<vtensor>>) -> !torch.vtensor<[1,768,32,32],f32>

I could be using the wrong dialect here; I'm not sure. Please let me know how to correctly get from the FxImporter output to iree-compile so that a vmfb can be generated, and about any pitfalls here. Thanks!

stellaraccident commented 7 months ago

Will need a torch-mlir implementation for that op. @rsuderman

We also have techniques to produce versions of these modules without the parameters inlined, but I'm still working on adapting them to the pure torch.export path. We will get there, but this is better worked around for now.
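iree-compile stops at the first illegal op, so with a module this large it can help to list every op the importer left as a generic torch.operator (which generally means torch-mlir had no dedicated op for it, as with _unsafe_index.Tensor above) before asking for lowerings. A minimal sketch, assuming op names appear in the custom assembly form shown in the diagnostics (torch.operator "name"(...)); a module printed in generic form would need a different pattern. count_torch_operators is a hypothetical helper name.

```python
import re
from collections import Counter

# Matches the custom form seen in the diagnostics above, e.g.
#   %57 = torch.operator "torch.aten._unsafe_index.Tensor"(%23, %56) : ...
OPERATOR_RE = re.compile(r'torch\.operator\s+"([^"]+)"')


def count_torch_operators(lines):
    """Count each op name wrapped in a generic torch.operator, one MLIR text line at a time.

    Iterating line by line means a file object can be passed directly, so a
    ~700 MB module does not need to be read into a single string first.
    """
    counts = Counter()
    for line in lines:
        counts.update(OPERATOR_RE.findall(line))
    return counts
```

For example: with open("dino.mlir") as f: print(count_torch_operators(f).most_common()). This only reports ops that reached the importer as torch.operator; ops that exist in the torch dialect but lack a lowering to linalg would still surface only in iree-compile.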

stellaraccident commented 7 months ago

Just confirming that getting past the torch error only needed a torch version bump? What version did you end up on?

harishanand95 commented 7 months ago

Yes, 2.3.0.dev20240308+cpu

harishanand95 commented 6 months ago

We now have this model working with the latest PT 2.3 support plus decompositions. Thanks! However, IREE compilation takes about 20 minutes. I have not validated the results yet; will update.

# torch                     2.3.0+cpu
# iree-compiler             20240327.844
# iree-runtime              20240327.844
# shark-turbine             0.9.7.dev1

import torch
import shark_turbine.aot as aot
from transformers.models.vit.modeling_vit import ViTModel  # 4.35.0
from huggingface_hub import hf_hub_download
import iree.compiler as ireec

class DINOModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(repo_id="facebook/dino-vitb16", filename="config.json")
            )
        )

    def forward(self, img):
        return self.model.forward(pixel_values=img, interpolate_pos_encoding=True)

def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
    flags = [
        "--iree-input-type=torch",
        "--mlir-print-debuginfo",
        "--mlir-print-op-on-diagnostic=false",
        "--iree-llvmcpu-target-cpu-features=host",
        "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
        "--iree-stream-resource-index-bits=64",
        "--iree-vm-target-index-bits=64",
        "--iree-flow-inline-constants-max-byte-length=1",
    ]
    if device == "cpu":
        flags.append("--iree-llvmcpu-enable-ukernels=all")
        device = "llvm-cpu"
    elif device == "vulkan":
        flags.extend(
            [
                "--iree-hal-target-backends=vulkan-spirv",
                "--iree-vulkan-target-triple=" + target_triple,
                "--iree-stream-resource-max-allocation-size=" + max_alloc,
            ]
        )
    elif device == "rocm":
        flags.extend(
            [
                "--iree-hal-target-backends=rocm",
                "--iree-rocm-target-chip=" + target_triple,
                "--iree-rocm-link-bc=true",
                "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                "--iree-vm-bytecode-module-strip-source-map=true",
                "--iree-opt-strip-assertions=true",
                "--iree-vm-target-truncate-unsupported-floats",
            ]
        )
    elif device == "cuda":
        flags.extend(
            [
                "--iree-hal-target-backends=cuda",
                "--iree-hal-cuda-llvm-target-arch=" + target_triple,
                "--iree-vm-bytecode-module-strip-source-map=true",
                "--iree-vm-target-truncate-unsupported-floats",
            ]
        )
    else:
        print("incorrect device: ", device)

    flatbuffer_blob = ireec.compile_str(
        module_str,
        target_backends=[device],
        extra_args=flags,
    )
    with open(f"{safe_name}.vmfb", "wb+") as f:
        f.write(flatbuffer_blob)
    print("Saved to", safe_name + ".vmfb")

def dino_export(device="cpu", target_triple="", compile_to="vmfb", max_alloc="4294967296"):
    model = DINOModel()
    example_args = (torch.randn(1, 3, 512, 512),)
    prog = aot.export(model, args=example_args)
    module_str = str(prog.mlir_module)
    with open("dino_model.mlir", "w+") as f:
        f.write(module_str)
    compile_to_vmfb(module_str, device, target_triple, max_alloc, "dino_model")

dino_export()

MLIR file

ScottTodd commented 6 months ago

One thing that's a bit weird about the code from the previous comment is the --iree-hal-target-backends= and device handling. It looks like the code as-is might be passing flags like --iree-hal-target-backends=vulkan-spirv twice. I'd check the logs for the exact command used.

This code:

    flatbuffer_blob = ireec.compile_str(
        module_str,
        target_backends=[device],
        extra_args=flags,
    )

goes down to here: https://github.com/openxla/iree/blob/11d22592941b934d47289fe03441a3ed07f14f08/compiler/bindings/python/iree/compiler/tools/core.py#L188-L189 , which adds the target backends flag for you:

    for target_backend in options.target_backends:
        cl.append(f"--iree-hal-target-backends={target_backend}")
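The double-flag problem can be checked with a small model of the command line. A minimal sketch, assuming (based only on the two quoted lines and the "Invoked with:" output earlier in the thread) that compile_str emits one --iree-hal-target-backends= flag per target_backends entry and then appends extra_args; assemble_command and backend_values are hypothetical helpers, not IREE APIs.

```python
def assemble_command(target_backends, extra_args):
    """Rough model of the iree-compile arguments built by compile_str (an assumption):
    one --iree-hal-target-backends flag per target_backends entry, then extra_args."""
    cl = [f"--iree-hal-target-backends={backend}" for backend in target_backends]
    cl.extend(extra_args)
    return cl


def backend_values(cl, flag="--iree-hal-target-backends"):
    """Collect every value given to `flag` on the assembled command line."""
    prefix = flag + "="
    return [arg[len(prefix):] for arg in cl if arg.startswith(prefix)]


# cpu branch of compile_to_vmfb: device is rewritten to "llvm-cpu" and flags carry no backend.
cpu_flags = ["--iree-llvmcpu-enable-ukernels=all"]
print(backend_values(assemble_command(["llvm-cpu"], cpu_flags)))  # ['llvm-cpu']

# vulkan branch: device stays "vulkan" while flags also name "vulkan-spirv".
vulkan_flags = [
    "--iree-hal-target-backends=vulkan-spirv",
    "--iree-vulkan-target-triple=ampere-rtx4070-linux",
]
print(backend_values(assemble_command(["vulkan"], vulkan_flags)))  # ['vulkan', 'vulkan-spirv']
```

In this model the cpu path yields a single backend, while the vulkan, rocm, and cuda branches each send two different values to the same flag.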

I'd either fully handle that flag yourself:

    if device == "cpu":
-       flags.append("--iree-llvmcpu-enable-ukernels=all")
-       device = "llvm-cpu"
+       flags.extend(
+           [
+               "--iree-hal-target-backends=llvm-cpu",
+               "--iree-llvmcpu-target-cpu-features=host",
+               "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
+               "--iree-llvmcpu-enable-ukernels=all",
+           ]
+       )

    flatbuffer_blob = ireec.compile_str(
        module_str,
-       target_backends=[device],
        extra_args=flags,
    )

OR consistently pass a value to target_backends:

    elif device == "vulkan":
        flags.extend(
            [
-               "--iree-hal-target-backends=vulkan-spirv",
                "--iree-vulkan-target-triple=" + target_triple,
                "--iree-stream-resource-max-allocation-size=" + max_alloc,
            ]
        )
+       device = "vulkan-spirv"
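Following the second option, the device handling can keep the backend name in exactly one place. A minimal sketch of that restructuring: iree_target is a hypothetical helper, the backend names and per-device flags are copied from compile_to_vmfb above, the llvmcpu flags are moved into the cpu branch, and --iree-llvmcpu-target-cpu-features=host is left out because a later comment in this thread reports that dropping it made compilation much faster.

```python
def iree_target(device, target_triple="", max_alloc="4294967296"):
    """Return (backend, extra_flags) for
    ireec.compile_str(module_str, target_backends=[backend], extra_args=extra_flags).

    The backend never appears in extra_flags, so --iree-hal-target-backends is
    emitted once, by compile_str.
    """
    common = [
        "--iree-input-type=torch",
        "--mlir-print-debuginfo",
        "--mlir-print-op-on-diagnostic=false",
        "--iree-stream-resource-index-bits=64",
        "--iree-vm-target-index-bits=64",
        "--iree-flow-inline-constants-max-byte-length=1",
    ]
    if device == "cpu":
        return "llvm-cpu", common + [
            "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
            "--iree-llvmcpu-enable-ukernels=all",
        ]
    if device == "vulkan":
        return "vulkan-spirv", common + [
            "--iree-vulkan-target-triple=" + target_triple,
            "--iree-stream-resource-max-allocation-size=" + max_alloc,
        ]
    if device == "rocm":
        return "rocm", common + [
            "--iree-rocm-target-chip=" + target_triple,
            "--iree-rocm-link-bc=true",
            "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
            "--iree-vm-bytecode-module-strip-source-map=true",
            "--iree-opt-strip-assertions=true",
            "--iree-vm-target-truncate-unsupported-floats",
        ]
    if device == "cuda":
        return "cuda", common + [
            "--iree-hal-cuda-llvm-target-arch=" + target_triple,
            "--iree-vm-bytecode-module-strip-source-map=true",
            "--iree-vm-target-truncate-unsupported-floats",
        ]
    # Fail before invoking iree-compile instead of printing and continuing.
    raise ValueError(f"unsupported device: {device!r}")
```

Usage would then be backend, flags = iree_target("vulkan", "ampere-rtx4070-linux"), followed by ireec.compile_str(module_str, target_backends=[backend], extra_args=flags). Raising on an unknown device surfaces a typo before iree-compile is ever invoked.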

Here are some trace screenshots from compiling the linked MLIR file on my machine:

harishanand95 commented 6 months ago

The results look good too, and it runs in seconds now. Removing --iree-llvmcpu-target-cpu-features=host helps. I ended up using .compile() on the ExportOutput.

# torch                     2.3.0+cpu
# iree-compiler             20240327.844
# iree-runtime              20240327.844
# shark-turbine             0.9.7.dev1

import torch
import shark_turbine.aot as aot
from transformers.models.vit.modeling_vit import ViTModel  # 4.35.0
from huggingface_hub import hf_hub_download
import iree.compiler as ireec
import iree.runtime as rt

class DINOModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(repo_id="facebook/dino-vitb16", filename="config.json")
            )
        )

    def forward(self, img):
        return self.model.forward(pixel_values=img, interpolate_pos_encoding=True, return_dict=False)

def dino_export(model, example_args):
    prog = aot.export(model, args=example_args)
    with open("dino_model.mlir", "w+") as f:
        f.write(str(prog.mlir_module))
    return prog.compile(save_to=None, target_backends="llvm-cpu") # try vulkan-spirv

def shark_infer(x, compiled_binary):
    config = rt.Config("local-task")
    vmm = rt.load_vm_module(
        rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
        config,
    )
    y = vmm.main(x)
    return y

model = DINOModel()
example_args = (torch.randn(1, 3, 512, 512),)
compiled_binary = dino_export(model, example_args)
y_p = model(*example_args)
y_s = shark_infer(*example_args, compiled_binary)

from numpy.testing import assert_almost_equal
assert_almost_equal(y_p[0].detach().numpy(), y_s[0].to_host(), decimal=3)

ScottTodd commented 6 months ago

Discussing this a bit offline. Something in the specific --iree-llvmcpu-target-cpu-features used on that machine is really slowing down the compiler. Going to try narrowing down which features are responsible, and on what IR.
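One way to narrow down the features is bisection. A minimal sketch, assuming a single culprit feature, a deterministic slowness check, and LLVM-style feature strings such as "+avx2" (the format passed to --iree-llvmcpu-target-cpu-features); find_slow_feature and the is_slow predicate are hypothetical, and the feature list itself, for example what host expands to on that machine, has to be obtained separately.

```python
def find_slow_feature(features, is_slow):
    """Binary-search a list of CPU feature strings for one that makes compilation slow.

    Assumes exactly one feature is responsible and that is_slow(subset) is True
    whenever the subset contains it. is_slow is caller-supplied, e.g. a function that
    times iree-compile with --iree-llvmcpu-target-cpu-features=",".join(subset).
    Returns the culprit, or None if the full list is not slow.
    """
    if not features or not is_slow(features):
        return None
    candidates = list(features)
    while len(candidates) > 1:
        mid = len(candidates) // 2
        # Keep whichever half still reproduces the slowdown.
        candidates = candidates[:mid] if is_slow(candidates[:mid]) else candidates[mid:]
    return candidates[0]
```

This needs roughly log2(n) + 1 compiles rather than one per feature. If features interact (two are needed together, or one implies another), the single-culprit assumption breaks and a delta-debugging style search would be a better fit.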