nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

StableDiffusion #152

Open stellaraccident opened 8 months ago

stellaraccident commented 8 months ago

Sync with Ean

monorimet commented 8 months ago
monorimet commented 7 months ago

@aviator19941 there are a few more missing pieces, mostly small preprocessors/schedulers that I'd like to have turbine implementations ready for. Please take a look at the following list of tasks and let me know if you have any questions.

gpetters94 commented 7 months ago

I can take Controlnet and the preprocessors @monorimet @aviator19941

aviator19941 commented 7 months ago

Sounds good, I'll take the (dynamic) SD schedulers then @gpetters94

monorimet commented 6 months ago

@aviator19941 Can we implement VAE encode for img2img? We could split it into two files, expose separate APIs (e.g. export_vae_encode_model and export_vae_decode_model), or use a single entry point with a keyword, e.g. export_vae_model(hf_model_id, variant="encode", ... )
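To make the two API shapes concrete, here is a minimal sketch of how they could coexist. Everything in it is hypothetical: the function names are taken from the comment above, but the return type, the "decode" default, and the validation are assumptions for illustration, and nothing is actually exported.

```python
from dataclasses import dataclass

# Hypothetical: the set of VAE halves a caller may ask for.
VALID_VARIANTS = ("encode", "decode")


@dataclass(frozen=True)
class VaeExportRequest:
    """Placeholder for whatever the real export would return (assumption)."""

    hf_model_id: str
    variant: str


def export_vae_model(hf_model_id: str, variant: str = "decode") -> VaeExportRequest:
    """Single entry point; `variant` selects which half of the VAE to export.

    The default of "decode" is an assumption, not something the thread decided.
    """
    if variant not in VALID_VARIANTS:
        raise ValueError(f"variant must be one of {VALID_VARIANTS}, got {variant!r}")
    # A real implementation would load the model and run the export here.
    return VaeExportRequest(hf_model_id=hf_model_id, variant=variant)


def export_vae_encode_model(hf_model_id: str) -> VaeExportRequest:
    """Thin wrapper giving the separate-API spelling on top of the keyword one."""
    return export_vae_model(hf_model_id, variant="encode")


def export_vae_decode_model(hf_model_id: str) -> VaeExportRequest:
    """Thin wrapper for the decode half."""
    return export_vae_model(hf_model_id, variant="decode")


request = export_vae_encode_model("CompVis/stable-diffusion-v1-4")
print(request.variant)
```

With this layout the keyword form is the only place that holds logic, so the two named wrappers cannot drift apart.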

aviator19941 commented 6 months ago

@stellaraccident @dan-garvey Exporting VAE encode fails when sample() is called: the error is raised from the out = function(*args_functional) call in _functionalize_callable() in functorch.py, and I am not sure how to go about fixing it. I have a small example that reproduces the same error that happens in latents = self.vae.encode(inp).latent_dist.sample(). However, when I provide concrete integers in place of inp.shape in the repro's forward function, i.e. torch.randn(1, 4, 64, 64, ...), it is able to compile to torch IR. Do you have any suggestions on how to fix this?

The Traceback is below:

Traceback (most recent call last):
  File "/home/avinash/nod/SHARK-Turbine/python/turbine_models/custom_models/sd_inference/vae_encode.py", line 37, in <module>
    exported = aot.export(sample_model, example_x)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 199, in export
    cm = Exported(context=context, import_to="import")
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 538, in __new__
    do_export(proc_def)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 535, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/tracer.py", line 121, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 516, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 183, in main
    return jittable(mdl.forward)(*args)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/base.py", line 137, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/tracer.py", line 137, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/builtins/jittable.py", line 207, in resolve_call
    transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/passes/functorch.py", line 47, in functorch_functionalize
    new_gm = proxy_tensor.make_fx(
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 809, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 468, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 485, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/passes/functorch.py", line 65, in wrapped
    out = function(*args_functional)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/builtins/jittable.py", line 202, in flat_wrapped_f
    return self.wrapped_f(*pytorch_args, **pytorch_kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/turbine_models/custom_models/sd_inference/vae_encode.py", line 30, in forward
    sample = torch.randn(inp.shape, generator=generator, device="cpu")
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 555, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 580, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 361, in proxy_call
    out = func(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 422, in constructors
    r = func(*args, **new_kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3404: SymIntArrayRef expected to contain only concrete integers

dan-garvey commented 6 months ago

import torch
import torch._dynamo as dynamo
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from typing import Optional
from shark_turbine.aot import *
from iree.compiler.ir import Context

class SampleModel(CompiledModule):

    def run_forward(self, inp=AbstractTensor(1,4,64,64, dtype=torch.float32)):
        sample = self.forward(inp)
        return sample

    @jittable
    def forward(inp) -> torch.FloatTensor:
        sample = torch.randn(inp.shape, device="cpu")
        # make sure sample is on the same device as the parameters and has same dtype
        sample = sample.to(device="cpu", dtype=torch.float32)
        return sample

sample_model = SampleModel(context=Context(), import_to="IMPORT")
print(str(CompiledModule.get_mlir_module(sample_model)))

This seems to fix your issue. Can you model the full thing this way?

aviator19941 commented 6 months ago

import shark_turbine.aot as aot
import torch
import torch._dynamo as dynamo
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from typing import Optional
from diffusers import AutoencoderKL
from shark_turbine.aot import *
from iree.compiler.ir import Context

hf_model_name = "CompVis/stable-diffusion-v1-4"

vae = AutoencoderKL.from_pretrained(
    hf_model_name,
    subfolder="vae",
)
class VaeModel(CompiledModule):

    def run_forward(self, inp=AbstractTensor(1, 3, 512, 512, dtype=torch.float32)):
        x = self.forward(inp)
        return x

    @jittable
    def forward(inp) -> torch.FloatTensor:
        latents = vae.encode(inp).latent_dist.sample()
        return 0.18215 * latents

vae_model = VaeModel(context=Context(), import_to="IMPORT")
print(str(CompiledModule.get_mlir_module(vae_model)))

I think the issue might occur when the jittable calls another function that needs to use the input shape directly. This example also seems to fail the SymIntArrayRef check.
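One possible workaround, sketched here in plain PyTorch and not tested against Turbine, is to keep the random draw out of the traced function entirely: the caller generates the noise with concrete sizes and passes it in as a second input, and the function only computes mean + std * noise. TinyEncoder below is a hypothetical stand-in for the VAE encoder (a single convolution producing mean and log-variance), not the diffusers model. The 0.18215 scaling factor is the one used in the snippet above.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a VAE encoder, for illustration only."""

    def __init__(self):
        super().__init__()
        # 3 image channels -> 8 channels (4 for mean, 4 for log-variance),
        # downsampling 512x512 to 64x64 with an 8x8 kernel and stride 8.
        self.conv = nn.Conv2d(3, 8, kernel_size=8, stride=8)

    def forward(self, image, noise):
        # Split the encoder output into the two distribution parameters.
        mean, logvar = torch.chunk(self.conv(image), 2, dim=1)
        std = torch.exp(0.5 * logvar)
        # Reparameterized sample: no torch.randn(shape) inside this function,
        # so no tensor has to be constructed from a (possibly symbolic) shape.
        return 0.18215 * (mean + std * noise)


model = TinyEncoder()
image = torch.randn(1, 3, 512, 512)
# The caller draws the noise with concrete integer sizes.
noise = torch.randn(1, 4, 64, 64)
latents = model(image, noise)
print(tuple(latents.shape))
```

Whether this sidesteps the SymIntArrayRef check when the function is wrapped in a jittable is an assumption; it is consistent with the observation above that concrete integers in torch.randn allow the repro to compile.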