huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
26.41k stars 5.43k forks

Type hint for `callback_on_step_end` in pipeline `__call__` is incorrect #6699

Open philpax opened 10 months ago

philpax commented 10 months ago

Describe the bug

As far as I can tell, the type hint for `callback_on_step_end` in all pipelines (as seen in this search) is incorrect.

Taking the SDXL pipeline as an example, the type hint in `__call__` is:

```python
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
```

but it is called like this:

```python
if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

    latents = callback_outputs.pop("latents", latents)
    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
    negative_pooled_prompt_embeds = callback_outputs.pop(
        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
    )
    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
```

That is, the type hint describes a function that takes three arguments and returns nothing, but the pipeline actually calls it with four arguments (the pipeline, the step index, the timestep and the kwargs dict) and expects a dict back. A callback written to match the hint therefore fails at runtime, while a correctly written four-argument callback is rejected by the type checker.
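For illustration, here is a minimal sketch of a hint that would match the call site above, assuming the callback receives the pipeline, the step index, the timestep and a dict of tensors, and returns a dict. `StepEndCallback` and `run_steps` are hypothetical names, not part of diffusers, and the toy loop only tracks `latents`. The pipeline and timestep are typed `Any` as an assumption, since `t` comes from the scheduler's timesteps and need not be an `int`.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical alias: (pipeline, step index, timestep, tensor kwargs) -> updated kwargs
StepEndCallback = Callable[[Any, int, Any, Dict[str, Any]], Dict[str, Any]]


def run_steps(
    pipe: Any,
    timesteps: List[Any],
    latents: Any,
    callback_on_step_end: Optional[StepEndCallback] = None,
) -> Any:
    """Toy loop that invokes the callback the same way the SDXL pipeline does."""
    for i, t in enumerate(timesteps):
        latents = latents * 0.5  # stand-in for one denoising step
        if callback_on_step_end is not None:
            callback_kwargs: Dict[str, Any] = {"latents": latents}
            callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)
            # Keys the callback does not return keep their current value.
            latents = callback_outputs.pop("latents", latents)
    return latents


def print_step(pipe: Any, step: int, timestep: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Matches StepEndCallback, so a type checker accepts it for the parameter above.
    print(step)
    return kwargs
```

For example, `run_steps(None, [999, 749, 499, 249], 16.0, print_step)` prints the step indices 0 to 3 and returns `1.0`. Because the loop uses `pop` with a default, a callback may return only the keys it wants to change.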

Reproduction

Using a Python type checker (Pyright in my case), attempt to pass a correctly defined callback to an SDXL pipeline:


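```python
import typing

import torch
from diffusers import AutoPipelineForText2Image, StableDiffusionXLPipeline

sdxl_pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
)
# Cast so the type checker sees StableDiffusionXLPipeline.__call__
sdxl_pipe = typing.cast(StableDiffusionXLPipeline, sdxl_pipe)

# Four arguments and a dict return value, matching how the pipeline calls it
def callback_on_step_end(_pipe, step, _timestep, kwargs):
    print(step)
    return kwargs

output = sdxl_pipe(
    prompt="a prompt",
    negative_prompt="a negative prompt",  # placeholder negative prompt
    num_inference_steps=4,
    guidance_scale=0.0,
    return_dict=True,
    width=512,
    height=512,
    callback_on_step_end=callback_on_step_end,
)
```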

The type checker reports an error on `callback_on_step_end`:

```
Argument of type "(_pipe: Unknown, step: Unknown, _timestep: Unknown, kwargs: Unknown) -> Unknown" cannot be assigned to parameter "callback_on_step_end" of type "((int, int, Dict[Unknown, Unknown]) -> None) | None" in function "__call__"
  Type "(_pipe: Unknown, step: Unknown, _timestep: Unknown, kwargs: Unknown) -> Unknown" cannot be assigned to type "((int, int, Dict[Unknown, Unknown]) -> None) | None"
    Type "(_pipe: Unknown, step: Unknown, _timestep: Unknown, kwargs: Unknown) -> Unknown" cannot be assigned to type "(int, int, Dict[Unknown, Unknown]) -> None"
      Function accepts too few positional parameters; expected 4 but received 3
    "function" is incompatible with "None"
```

This can be worked around with `# type: ignore`, which is what I'm doing for now.
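A narrower alternative to `# type: ignore`, sketched under the assumption that the four-argument runtime signature is the one you want, is to cast only the callback to `Any`. The rest of the pipeline call then stays type-checked. `on_step_end` and `callback_for_pipeline` are hypothetical names.

```python
from typing import Any, Dict, cast


def on_step_end(pipe: Any, step: int, timestep: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Four arguments, returning the (possibly updated) kwargs dict.
    print(step)
    return kwargs


# cast() is a no-op at runtime; it only stops the type checker from comparing
# this function against the declared (incorrect) hint.
callback_for_pipeline = cast(Any, on_step_end)

# Usage (needs a loaded pipeline, so shown as a comment):
# output = sdxl_pipe(prompt="a prompt", callback_on_step_end=callback_for_pipeline)
```

Unlike a line-level `# type: ignore`, this keeps errors in the other arguments of the call visible.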

Logs

No response

System Info

N/A

Who can help?

No response

yiyixuxu commented 10 months ago

Thanks for creating this issue! Do you want to open a PR to fix the type hint?

philpax commented 10 months ago

Hi, sorry for the late response! Yes, I can take a look at it; should hopefully be a relatively straightforward fix.

Edit: I'm not currently working with diffusers, so I haven't been able to work on this fix. It's free for anyone else to take.

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

a-r-r-o-w commented 1 week ago

Thanks for reporting! I will do a type-hint sprint soon and try to take this into account.