huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Schedulers not compatible with OnnxStableDiffusionPipeline: TypeError: unsupported operand #967

Closed Tnifey closed 2 years ago

Tnifey commented 2 years ago

Describe the bug

Hi, I tried to use different schedulers with OnnxStableDiffusionPipeline, but they throw errors. The schedulers are not compatible with the NumPy arrays used in the ONNX pipeline.

The ONNX checkpoints were converted with: convert_stable_diffusion_checkpoint_to_onnx.py

I have found a workaround, but it is probably not optimal because it uses torch inside the pipeline call.

Solution works with:

First error:

File "C:\...\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py", line 152, in __call__
  latents = latents * self.scheduler.init_noise_sigma
TypeError: unsupported operand type(s) for *: 'numpy.ndarray' and 'Tensor'

If I cast latents to a torch tensor before multiplying by init_noise_sigma:

latents = torch.tensor(latents)
latents = latents * self.scheduler.init_noise_sigma

Second error:

File "C:\...\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py", line 171, in __call__
  noise_pred = self.unet(
File "C:\...\diffusers\onnx_utils.py", line 46, in __call__
  return self.model.run(None, inputs)
File "C:\...\onnxruntime\capi\onnxruntime_inference_collection.py", line 200, in run
  return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(double)) , expected: (tensor(int64))
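
The "Actual: (tensor(double))" part of this message is consistent with the timestep reaching ONNX Runtime as a float64 array. A minimal sketch of how that can happen, assuming the scheduler hands back floating-point timesteps (the `timesteps` array below is a hypothetical stand-in, not taken from the scheduler source):

```python
import numpy as np

# Hypothetical timesteps: assume the scheduler produced them as floats.
timesteps = np.linspace(0, 999, 5, dtype=float)[::-1]
t = timesteps[0]

# Wrapping a float in np.array keeps float64, which ONNX Runtime
# reports as tensor(double).
as_passed = np.array([t])

# Requesting int64 explicitly matches the input type named in the error.
# Note that this truncates any fractional timestep value.
as_int64 = np.array([t], dtype=np.int64)

print(as_passed.dtype, as_int64.dtype)
```

Whether truncating fractional timesteps is acceptable depends on the scheduler, so treat the cast as a workaround rather than a fix.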

If I add dtype=np.int64 to the timestep in the unet args:

# predict the noise residual
noise_pred = self.unet(
  sample=latent_model_input,
  timestep=np.array([t], dtype=np.int64),
  encoder_hidden_states=text_embeddings,
)

Third error:

File "C:\...\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py", line 184, in __call__
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
File "C:\...\diffusers\schedulers\scheduling_lms_discrete.py", line 224, in step
  pred_original_sample = sample - sigma * model_output
TypeError: unsupported operand type(s) for -: 'numpy.ndarray' and 'Tensor'

And if I cast latents to a torch tensor before passing them into scheduler.step, and convert the result back to a NumPy array afterwards, it works again:

# compute the previous noisy sample x_t -> x_t-1
latents = torch.tensor(latents)
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = np.array(latents)
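
The same round trip can be sketched in isolation. This is only an illustration under stated assumptions: `fake_scheduler_step` is a hypothetical stand-in for a torch-only scheduler.step, and the shapes and values are made up.

```python
import numpy as np
import torch


def fake_scheduler_step(model_output, sigma, sample):
    # Hypothetical stand-in for a scheduler step that does torch-only math.
    return sample - sigma * model_output


def step_from_numpy(noise_pred, sigma, latents):
    # Convert the NumPy inputs to tensors, run the torch step,
    # then convert the result back to NumPy for the ONNX pipeline.
    out = fake_scheduler_step(
        torch.from_numpy(noise_pred), sigma, torch.from_numpy(latents)
    )
    return out.numpy()


latents = np.ones((1, 4, 8, 8), dtype=np.float32)
noise_pred = np.full((1, 4, 8, 8), 0.5, dtype=np.float32)
sigma = torch.tensor(2.0)

result = step_from_numpy(noise_pred, sigma, latents)
print(type(result), result.dtype)
```

torch.from_numpy shares memory with the source array instead of copying it, which keeps the conversion cheap on CPU; torch.tensor(latents), as used above, makes a copy.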

Reproduction

from diffusers import OnnxStableDiffusionPipeline, LMSDiscreteScheduler

lms = LMSDiscreteScheduler()

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    model_path,
    provider="DmlExecutionProvider",
    scheduler=lms,
    local_files_only=True,
)

image = pipe("prompt")[0]

Logs

No response

System Info

patrickvonplaten commented 2 years ago

@anton-l could you take a look here?

averad commented 2 years ago

@anton-l @patrickvonplaten The only scheduler that currently works out of the box with the ONNX pipeline is PNDMScheduler. Is there a fix for this in Diffusers 0.7.0 beta, or are there any updates?

Amazing work getting the ONNX pipeline working with Img2Img as well as Inpainting in Diffusers >= 0.6.0.

anton-l commented 2 years ago

Sorry for missing this, will be fixed in 0.7.1 in a couple of days!

averad commented 2 years ago

@anton-l you and your team members are the best, thank you.

averad commented 2 years ago

@anton-l I installed diffusers 0.7.1 and tested the following schedulers using the ONNX pipeline:

from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, EulerDiscreteScheduler
scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 
scheduler = EulerDiscreteScheduler.from_config('./stable_diffusion_v1-5', subfolder="scheduler")

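Errors

PNDMScheduler, LMSDiscreteScheduler, and EulerDiscreteScheduler fail with the ONNX Runtime INVALID_ARGUMENT error (actual input type tensor(double), expected tensor(int64)).

DDIMScheduler fails with a TypeError: unsupported operand type(s) for -: 'numpy.ndarray' and 'Tensor'.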

Log (PNDMScheduler)

Traceback (most recent call last):
  File "D:\ai\txt_to_img.py", line 130, in <module>
    txt_to_img(prompt, negative_prompt, int(num_inference_steps), int(width), int(height), seed)
  File "D:\ai\txt_to_img.py", line 79, in txt_to_img
    image = pipe(
  File "D:\ai\sd_env\lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py", line 206, in __call__
    noise_pred = self.unet(
  File "D:\ai\sd_env\lib\site-packages\diffusers\onnx_utils.py", line 46, in __call__
    return self.model.run(None, inputs)
  File "D:\ai\sd_env\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(double)) , expected: (tensor(int64))

Log (DDIMScheduler)

Traceback (most recent call last):
  File "D:\ai\txt_to_img.py", line 130, in <module>
    txt_to_img(prompt, negative_prompt, int(num_inference_steps), int(width), int(height), seed)
  File "D:\ai\txt_to_img.py", line 79, in txt_to_img
    image = pipe(
  File "D:\ai\sd_env\lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py", line 217, in __call__
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
  File "D:\ai\sd_env\lib\site-packages\diffusers\schedulers\scheduling_ddim.py", line 263, in step
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
TypeError: unsupported operand type(s) for -: 'numpy.ndarray' and 'Tensor'

Log (LMSDiscreteScheduler)

Traceback (most recent call last):
  File "D:\ai\txt_to_img.py", line 130, in <module>
    txt_to_img(prompt, negative_prompt, int(num_inference_steps), int(width), int(height), seed)
  File "D:\ai\txt_to_img.py", line 79, in txt_to_img
    image = pipe(
  File "D:\ai\sd_env\lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py", line 206, in __call__
    noise_pred = self.unet(
  File "D:\ai\sd_env\lib\site-packages\diffusers\onnx_utils.py", line 46, in __call__
    return self.model.run(None, inputs)
  File "D:\ai\sd_env\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(double)) , expected: (tensor(int64))

Log (EulerDiscreteScheduler)

Traceback (most recent call last):
  File "D:\ai\test.py", line 44, in <module>
    image = pipe.text2img(prompt, negative_prompt = neg_prompt, width = width, height = height, num_inference_steps = num_inference_steps, latents = latents, max_embeddings_multiples=3).images[0]
  File "D:\ai\sd_env\lpw_stable_diffusion_onnx.py", line 789, in text2img
    return self.__call__(
  File "D:\ai\sd_env\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\ai\sd_env\lpw_stable_diffusion_onnx.py", line 651, in __call__
    noise_pred = self.unet(
  File "D:\ai\sd_env\lib\site-packages\diffusers\onnx_utils.py", line 46, in __call__
    return self.model.run(None, inputs)
  File "D:\ai\sd_env\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(double)) , expected: (tensor(int64))

anton-l commented 2 years ago

@averad https://github.com/huggingface/diffusers/pull/1173 should fix the scheduler issues once merged