siliconflow / onediff

OneDiff: An out-of-the-box acceleration library for diffusion models.
https://github.com/siliconflow/onediff/wiki
Apache License 2.0

Support for custom attention processor? #873

Closed: Davin05 closed this issue 4 months ago

Davin05 commented 4 months ago

Hi,

Thanks for making onediff possible. I am experimenting with my own custom attention processor, which I plug into a diffusers-based UNet via set_attn_processor:

unet.set_attn_processor(customized_attn_processors_dicts)

However, whenever I compile my model with oneflow_compile, it raises a NotImplementedError. The detailed error message is attached below:

ERROR:    Traceback (most recent call last):
  .
  .
  .
  File "/venv/lib/python3.10/site-packages/diffusers/models/attention.py", line 329, in forward
    attn_output = self.attn1(
  File "/venv/lib/python3.10/site-packages/oneflow/nn/graph/proxy.py", line 188, in __call__
    result = self.__block_forward(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/oneflow/nn/graph/proxy.py", line 238, in __block_forward
    result = unbound_forward_of_module_instance(self, *args, **kwargs)
  File "/venv/lib/python3.10/site-packages/infer_compiler_registry/register_diffusers/attention_processor_oflow.py", line 363, in forward
    return self.processor(
  File "/venv/lib/python3.10/site-packages/oneflow/nn/graph/proxy.py", line 188, in __call__
    result = self.__block_forward(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/oneflow/nn/graph/proxy.py", line 238, in __block_forward
    result = unbound_forward_of_module_instance(self, *args, **kwargs)
  File "/venv/lib/python3.10/site-packages/oneflow/nn/modules/module.py", line 200, in forward
    raise NotImplementedError()
NotImplementedError

Is there a way to use a custom attention processor together with onediff? Thanks!

Another (perhaps important) detail: my attention processor is a subclass of nn.Module. Could that matter?

Davin05 commented 4 months ago

I think I found the issue: a subclass of nn.Module must implement forward for onediff to be able to compile it.
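The traceback is consistent with that diagnosis: the graph proxy calls `unbound_forward_of_module_instance(self, ...)`, which lands in the base `Module.forward` (module.py, line 200) and raises. Diffusers' built-in processors put their logic in `__call__`, so a processor that subclasses nn.Module in the same style never overrides `forward`. The minimal pure-Python sketch below reproduces that dispatch pattern. It assumes simplified stand-ins: `Module`, `graph_proxy_call`, `BrokenProcessor`, `FixedProcessor` and `double` are all hypothetical and only mimic the behaviour visible in the traceback. They are not the real oneflow, torch or onediff classes.

```python
# Hypothetical stand-ins that mimic ONLY the dispatch behaviour visible in
# the traceback; these are not the real oneflow/torch/onediff classes.


class Module:
    """Stand-in for nn.Module: the base forward is abstract."""

    def forward(self, *args, **kwargs):
        raise NotImplementedError()

    def __call__(self, *args, **kwargs):
        # Eager calls go through __call__, which dispatches to forward.
        return self.forward(*args, **kwargs)


def graph_proxy_call(module, *args, **kwargs):
    """Stand-in for the graph proxy: it calls the module's forward directly,
    bypassing any __call__ override defined by the subclass."""
    return type(module).forward(module, *args, **kwargs)


class BrokenProcessor(Module):
    # Diffusers-style processor: the logic lives in __call__,
    # so the abstract Module.forward is never overridden.
    def __call__(self, attn, hidden_states, **kwargs):
        return attn(hidden_states)


class FixedProcessor(Module):
    # The logic lives in forward; the inherited __call__ delegates to it,
    # so eager calls and the graph proxy both reach the same code.
    def forward(self, attn, hidden_states, **kwargs):
        return attn(hidden_states)


def double(x):
    # Hypothetical placeholder for the attention computation.
    return 2 * x


print(BrokenProcessor()(double, 3))                   # → 6 (eager call works)
print(FixedProcessor()(double, 3))                    # → 6
print(graph_proxy_call(FixedProcessor(), double, 3))  # → 6
try:
    graph_proxy_call(BrokenProcessor(), double, 3)
except NotImplementedError:
    print("BrokenProcessor: NotImplementedError under the graph proxy")
```

This shows why the model can run eagerly yet fail under oneflow_compile. For a real processor, the analogous change would presumably be to move the body of `__call__` into `forward` with the same signature and drop the `__call__` override, so the base `__call__` dispatches to it.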