entrpn closed this issue 1 year ago.
Thanks a lot for the report, and sorry for the trouble @entrpn!
Workaround:
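import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Load the bf16 weights and skip the safety checker and its feature extractor
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    revision="bf16",
    dtype=jnp.bfloat16,
    safety_checker=None,
    feature_extractor=None,
)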
The proper solution would be to replicate something like #1395 in pipeline_flax_utils.py, which I suspect will take a bit of effort to get right before merging. Will iterate on it in the next few days! /cc @patrickvonplaten
Thank you @pcuenca, the workaround works for me.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Describe the bug
When trying to load Stable Diffusion 2.1 using Flax, I am getting the following error:
Reproduction
Create a TPU VM and run the following installation:
Then run the script as follows:
python infer.py --sd-version 2 --itters 3
Logs
System Info
diffusers version: 0.22.0.dev0
Who can help?
No response