mlc-ai / web-stable-diffusion

Bringing stable diffusion models to web browsers. Everything runs inside the browser with no server support.
https://mlc.ai/web-stable-diffusion
Apache License 2.0
3.56k stars 227 forks

Cannot build stable diffusion model: "BackendCompilerFailed: backend='_capture' raised AssertionError" #24

Open loicmagne opened 1 year ago

loicmagne commented 1 year ago

I tried building the stable diffusion model using the walkthrough.ipynb notebook and the build.py file, but when I run the "Combine every piece together" part:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
clip = clip_to_text_embeddings(pipe)
unet = unet_latents_to_noise_pred(pipe, torch_dev_key)
vae = vae_to_image(pipe)
concat_embeddings = concat_embeddings()
image_to_rgba = image_to_rgba()
schedulers = [
    dpm_solver_multistep_scheduler_steps(),
    trace.PNDMScheduler.scheduler_steps()
]

mod: tvm.IRModule = utils.merge_irmodules(
    clip,
    unet,
    vae,
    concat_embeddings,
    image_to_rgba,
    *schedulers,
)

Both result in the same error:

│ /usr/local/lib/python3.10/dist-packages/torch/__init__.py:1565 in __call__                       │
│                                                                                                  │
│   1562 │   │   │   │   self.dynamic == other.dynamic)                                            │
│   1563 │                                                                                         │
│   1564 │   def __call__(self, model_, inputs_):                                                  │
│ ❱ 1565 │   │   return self.compiler_fn(model_, inputs_, **self.kwargs)                           │
│   1566                                                                                           │
│   1567                                                                                           │
│   1568 def compile(model: Optional[Callable] = None, *,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/torch/dynamo.py:151 in _capture       │
│                                                                                                  │
│   148 │   def _capture(graph_module: fx.GraphModule, example_inputs):                            │
│   149 │   │   assert isinstance(graph_module, torch.fx.GraphModule)                              │
│   150 │   │   input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inp   │
│ ❱ 151 │   │   mod_ = from_fx(                                                                    │
│   152 │   │   │   graph_module,                                                                  │
│   153 │   │   │   input_info,                                                                    │
│   154 │   │   │   keep_params_as_input=keep_params_as_input,                                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/torch/fx_translator.py:1387 in        │
│ from_fx                                                                                          │
│                                                                                                  │
│   1384 │   to print out the tabular representation of the PyTorch module, and then               │
│   1385 │   check the placeholder rows in the beginning of the tabular.                           │
│   1386 │   """                                                                                   │
│ ❱ 1387 │   return TorchFXImporter().from_fx(                                                     │
│   1388 │   │   model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_retur  │
│   1389 │   )                                                                                     │
│   1390                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/torch/fx_translator.py:1282 in        │
│ from_fx                                                                                          │
│                                                                                                  │
│   1279 │   │   │   │   │   │   self.env[node] = self.convert_map[node.target](node)              │
│   1280 │   │   │   │   │   else:                                                                 │
│   1281 │   │   │   │   │   │   raise ValueError(f"Unsupported op {node.op}")                     │
│ ❱ 1282 │   │   │   assert output is not None                                                     │
│   1283 │   │   │   self.block_builder.emit_func_output(output)                                   │
│   1284 │   │                                                                                     │
│   1285 │   │   mod = self.block_builder.get()                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: backend='_capture' raised:
AssertionError: 

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

It seems there is a problem with TorchDynamo.

Also, a somewhat unrelated error: I couldn't install the CUDA version of the mlc/tvm package:

!python3 -m pip install mlc-ai-nightly-cu116 -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
ERROR: Could not find a version that satisfies the requirement mlc-ai-nightly-cu116 (from versions: none)
ERROR: No matching distribution found for mlc-ai-nightly-cu116
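pip reports "from versions: none" when the index has no wheel whose tags match the running interpreter, so one possible explanation (an assumption, not confirmed by the error itself) is that no mlc-ai-nightly-cu116 wheel is published for Colab's Python version and platform. A quick stdlib check of the interpreter's tag components can confirm what pip is matching against:

```python
import sys
import sysconfig

# pip only installs wheels whose Python-version and platform tags
# match the running interpreter; if no cu116 wheel exists for this
# combination on https://mlc.ai/wheels, resolution fails with
# "from versions: none".
print(sys.version_info[:2])      # e.g. (3, 10) on current Colab
print(sysconfig.get_platform())  # e.g. linux-x86_64
```

Comparing these values against the wheel filenames listed at https://mlc.ai/wheels would show whether a matching cu116 build exists at all.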

Both errors can be reproduced by running the notebook on Google Colab.