Open stellaraccident opened 1 year ago
@aviator19941 there are a few more missing pieces, mostly small preprocessors/schedulers that I'd like to have turbine implementations ready for. Please take a look at the following list of tasks and let me know if you have any questions.
I can take Controlnet and the preprocessors @monorimet @aviator19941
sounds good, I'll take (dynamic) SD schedulers then @gpetters94
@aviator19941 can we implement VAE encode for img2img? We can split it into two files, or just have separate APIs, e.g. `export_vae_encode_model` and `export_vae_decode_model`,
or just use a keyword argument: `export_vae_model(hf_model_id, variant="encode", ... )`.
@stellaraccident @dan-garvey For VAE encode, calling `sample()` fails at the `out = function(*args_functional)` call in `_functionalize_callabale()` in functorch.py. I am not sure how to go about fixing this error. I have a small example that reproduces the same error that happens in `latents = self.vae.encode(inp).latent_dist.sample()`, but when I provide concrete integers in place of `inp.shape` in the repro's forward function, i.e. `torch.randn(1, 4, 64, 64, ...)`, it is able to compile to torch IR. Do you have any suggestions on how to fix this?
The Traceback is below:
Traceback (most recent call last):
File "/home/avinash/nod/SHARK-Turbine/python/turbine_models/custom_models/sd_inference/vae_encode.py", line 37, in <module>
exported = aot.export(sample_model, example_x)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 199, in export
cm = Exported(context=context, import_to="import")
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 538, in __new__
do_export(proc_def)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 535, in do_export
trace.trace_py_func(invoke_with_self)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/tracer.py", line 121, in trace_py_func
return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 516, in invoke_with_self
return proc_def.callable(self, *args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 183, in main
return jittable(mdl.forward)(*args)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/base.py", line 137, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/tracer.py", line 137, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/builtins/jittable.py", line 207, in resolve_call
transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/passes/functorch.py", line 47, in functorch_functionalize
new_gm = proxy_tensor.make_fx(
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 809, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 468, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
(self.create_arg(fn(*args)),),
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 485, in wrapped
out = f(*tensors)
File "<string>", line 1, in <lambda>
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/passes/functorch.py", line 65, in wrapped
out = function(*args_functional)
File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/builtins/jittable.py", line 202, in flat_wrapped_f
return self.wrapped_f(*pytorch_args, **pytorch_kwargs)
File "/home/avinash/nod/SHARK-Turbine/python/turbine_models/custom_models/sd_inference/vae_encode.py", line 30, in forward
sample = torch.randn(inp.shape, generator=generator, device="cpu")
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 555, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 580, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 361, in proxy_call
out = func(*args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in dispatch
op_impl_out = op_impl(self, func, *args, **kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 422, in constructors
r = func(*args, **new_kwargs)
File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3404: SymIntArrayRef expected to contain only concrete integers
import torch
import torch._dynamo as dynamo
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from typing import Optional
from shark_turbine.aot import *
from iree.compiler.ir import Context
class SampleModel(CompiledModule):
    """Minimal repro module: produce random noise shaped like the input."""

    def run_forward(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)):
        """Entry point taking a 1x4x64x64 float32 tensor; delegates to ``forward``."""
        return self.forward(inp)

    @jittable
    def forward(inp) -> torch.FloatTensor:
        """Return a CPU float32 tensor of normal noise with the shape of ``inp``."""
        noise = torch.randn(inp.shape, device="cpu")
        # Pin device and dtype explicitly so the result matches the
        # parameters regardless of torch defaults.
        return noise.to(device="cpu", dtype=torch.float32)
# Instantiate the module in a fresh IREE context, then print the MLIR
# module that CompiledModule reports for it.
sample_model = SampleModel(context=Context(), import_to="IMPORT")
sample_mlir = CompiledModule.get_mlir_module(sample_model)
print(str(sample_mlir))
The snippet above seems to fix your issue. Can you model the full thing this way?
import shark_turbine.aot as aot
import torch
import torch._dynamo as dynamo
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from typing import Optional
from diffusers import AutoencoderKL
from shark_turbine.aot import *
from iree.compiler.ir import Context
# Hugging Face model id whose "vae" subfolder provides the autoencoder.
hf_model_name = "CompVis/stable-diffusion-v1-4"
# Loaded once at module level; VaeModel.forward closes over this object.
vae = AutoencoderKL.from_pretrained(hf_model_name, subfolder="vae")
class VaeModel(CompiledModule):
    """Compiled module wrapping the encode path of the module-level ``vae``."""

    def run_forward(self, inp=AbstractTensor(1, 3, 512, 512, dtype=torch.float32)):
        """Entry point taking a 1x3x512x512 float32 tensor; delegates to ``forward``."""
        return self.forward(inp)

    @jittable
    def forward(inp) -> torch.FloatTensor:
        """Encode ``inp`` with the VAE, sample the latent distribution, and scale."""
        posterior = vae.encode(inp).latent_dist
        # NOTE(review): 0.18215 is presumably the Stable Diffusion latent
        # scaling factor -- confirm against the model config.
        return 0.18215 * posterior.sample()
# Instantiate the module in a fresh IREE context, then print the MLIR
# module that CompiledModule reports for it.
vae_model = VaeModel(context=Context(), import_to="IMPORT")
vae_mlir = CompiledModule.get_mlir_module(vae_model)
print(str(vae_mlir))
I think the issue might be when the jittable calls another function that needs to use the input shape directly. Seems like this example also fails the SymIntArrayRef check.
Sync with Ean