Describe the bug
In the current stable diffusion code example, only the model `CompVis/stable-diffusion-v1-4-flax` is supported. Request: add a `from_pt` bool arg and pass it into `CLIPTokenizer.from_pretrained`, `FlaxCLIPTextModel.from_pretrained`, `FlaxAutoencoderKL.from_pretrained`, and `FlaxUNet2DConditionModel.from_pretrained` to enable conversion from the PyTorch model, so that `stabilityai/stable-diffusion-2-1` can be used in Flax.

Reproduction
A successful example run on TPU is available in my forked repo (https://github.com/RissyRan/diffusers.git):
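A minimal sketch of the requested change, assuming the loaders follow the usual Hugging Face `from_pretrained` pattern (the `load_sd_flax` helper name and the `from_pt` flag on the diffusers Flax loaders are assumptions; this is what the issue asks to add, not the current API):

```python
# Hypothetical helper threading a from_pt flag through the four
# from_pretrained calls, so a PyTorch-only checkpoint such as
# stabilityai/stable-diffusion-2-1 could be loaded in Flax.
def load_sd_flax(model_id, from_pt=False):
    from transformers import CLIPTokenizer, FlaxCLIPTextModel
    from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel

    # Tokenizer files are framework-agnostic; no conversion needed.
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    # from_pt=True asks each Flax loader to convert PyTorch weights on the fly.
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", from_pt=from_pt
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        model_id, subfolder="vae", from_pt=from_pt
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        model_id, subfolder="unet", from_pt=from_pt
    )
    return tokenizer, text_encoder, vae, vae_params, unet, unet_params
```

With such a flag, `load_sd_flax("stabilityai/stable-diffusion-2-1", from_pt=True)` would cover the use case described above.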
Logs
System Info
Google Cloud TPU
Who can help?
@yiyixuxu @patrickvonplaten