Open · krahnikblis opened this issue 1 year ago
Hey @krahnikblis,
That's very interesting. Just to better understand whether this is a general problem or a Flax-only problem: did you experience the same when using the PyTorch img2img pipeline (e.g. https://huggingface.co/docs/diffusers/using-diffusers/img2img)?
Also cc @pcuenca @patil-suraj
@patrickvonplaten no... well, i never had to do any custom transformations on the input that i didn't also do on the output. with torch stuff, i'm just plugging together things that were already developed to be interoperable, so i don't look under the hood as much. so now let's...
opening up my notebook where i used the torch img2img, it looks like it takes PIL both in and out. i guess my perception was that img2img would take as input the same format/type that text2img training takes (i.e. they both need captions and images). BUT, taking a peek at one of your existing torch training scripts, i see that the image data loader is doing a bunch of transformation gymnastics (from PIL to numpy to PIL to numpy to torch, lines 363-397 lolz), i.e. the torch text2img training does not (directly) take the same PIL input that img2img takes, so perhaps my perception/expectation is off? and line 395 in that training script, `image / 127.5 - 1.0`, maps the [0, 255] pixel values to [-1, 1], i.e. centered on 0, which is what my own `* 2 - 1` adjustment shows is expected as input...
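For reference, `image / 127.5 - 1.0` applied to uint8 pixels and `* 2 - 1` applied to floats already scaled to [0, 1] are the same affine map onto [-1, 1]. A minimal numpy check, assuming 8-bit RGB input:

```python
import numpy as np

# every possible 8-bit pixel value
pixels_uint8 = np.arange(256, dtype=np.uint8)

# training-script style: uint8 [0, 255] -> float [-1, 1]
train_style = pixels_uint8 / 127.5 - 1.0

# numpy-image style: scale to [0, 1] first, then apply * 2 - 1
unit_range = pixels_uint8 / 255.0
adjusted = unit_range * 2 - 1

print(train_style.min(), train_style.max())
print(np.allclose(train_style, adjusted))
```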
basically, i'm replicating some of the neat new stuff that's available in torch in Jax, specifically to work well on TPUs (e.g. textual inversion, low-rank approximation/extraction (btw, given y'all host like PBs of people's fine-tuned models on your site, i bet storage costs are high for your company; y'all might reaaaaallly like the cost savings and end-user setup-time reduction enabled by this over here, e.g. i've low-ranked like 15 of the models on the HF site down to ~20MB each, vs. their ~8GB folder sizes)). for my workflow i'm operating under the expectation that text2img training and the img2img pipe could use the same image/caption loader, but the img2img pipe has this issue with the same loader that works for text2img training. and, comparing again to the torch img2img, which uses PIL both in and out, my own humble expectation would be that the Jax img2img would also use the same format in and out.
but i could totally be wrong here; i guess i'm the only one having this issue or working on it from this angle, since it was developed, tested, and released, i.e. no one else ran into whitewashed images... so... ¯\_(ツ)_/¯
Maybe cc @pcuenca @patil-suraj here - are you able to reproduce this problem on TPU in JAX?
If I understand correctly, your question is why the image format differs between input/training and output.
This is because, during training, images are normalized so that they have a uniform range, which helps the model train better. A simple Google search for why image normalization is needed will turn up lots of blogs/answers that explain it in detail. Normalizing to [-1, 1] is a very common practice, and this is how Stable Diffusion is trained. That's why we normalize images when inputting them into the model during inference. But that's not necessary for the output: for the output, we want the actual values between 0 and 1, so they can be converted to an image.
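A minimal numpy sketch of the two conversions described above: inputs go from [0, 1] to [-1, 1] before encoding, and decoded outputs come back with `x / 2 + 0.5`, clipped to [0, 1]. The function names are hypothetical; the output mapping is assumed to mirror what the Stable Diffusion pipelines do after the VAE decode, so check the pipeline source for the exact code.

```python
import numpy as np

def normalize_for_model(image01):
    """[0, 1] float image -> [-1, 1], the range the model was trained on."""
    return image01 * 2.0 - 1.0

def denormalize_output(decoded):
    """Decoder output (roughly [-1, 1]) -> [0, 1], clipping any overshoot."""
    return np.clip(decoded / 2.0 + 0.5, 0.0, 1.0)

image = np.array([0.0, 0.25, 0.5, 1.0])
model_input = normalize_for_model(image)

# stand-in for a perfect encode/decode round trip, plus one overshooting value
decoded = np.append(model_input, 1.1)
restored = denormalize_output(decoded)
print(model_input)
print(restored)
```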
Hope this helps.
@patil-suraj thanks... it was more a question of "did y'all mean to leave it that way", i.e. have a different format go in than comes out, or was the pipe meant to convert a standard np-format image to a [-1,1]-ranged image internally, so the user doesn't need to deal with it?
example workflow: use text2img to generate something, then use img2img to refine it further, then use inpaint to fix a small remaining imperfection. in this scenario, one would expect the output of the text2img pipe to be directly compatible with the input to img2img, but as-is today, the image must be rescaled to a different range of values. it's not hard, but it's a step that feels unnatural and that some users might not know to do. the rest of my commentary re formats is more my own bemusement that there is no single consistent way of handling image data in general, and isn't material to my question.
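As a sketch of that extra step: assuming the Flax text2img output is a float array in [0, 1] with shape `(num_devices, batch_per_device, H, W, 3)` (as returned with `jit=True`), and that img2img wants `(batch, 3, H, W)` in [-1, 1], a hypothetical adapter (not part of diffusers) could look like this. The result would still need to be sharded or replicated as usual before calling the pipeline.

```python
import numpy as np

def text2img_output_to_img2img_input(images):
    """Hypothetical adapter: text2img output -> img2img input.

    Assumes `images` is float in [0, 1] with shape (..., H, W, 3), e.g.
    (num_devices, batch_per_device, H, W, 3). Returns (N, 3, H, W) in [-1, 1].
    """
    images = np.asarray(images, dtype=np.float32)
    height, width, channels = images.shape[-3:]
    flat = images.reshape(-1, height, width, channels)  # merge device/batch axes
    chw = flat.transpose(0, 3, 1, 2)                     # NHWC -> NCHW
    return chw * 2.0 - 1.0                               # [0, 1] -> [-1, 1]

# random stand-in for a text2img result on 8 devices, 1 image each
fake_output = np.random.default_rng(0).random((8, 1, 64, 64, 3))
img2img_input = text2img_output_to_img2img_input(fake_output)
print(img2img_input.shape)
```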
Could you provide a code snippet that shows the issue? That would help us understand it better.
ok, here you go:
##### img2img in/out diff proof
# installs and TPU setup
!pip install -qq -U jax jaxlib
!pip install -qq flax optax transformers ftfy diffusers
### sometimes requires restart at this point, because google.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu(tpu_driver_version='tpu_driver_nightly')
import os
os.environ["USE_FLAX"] = "1" # suddenly might be needed to use transformers flax version ?
os.environ["XLA_USE_BF16"] = "1" # XLA TPU OOM error, Colab TPU has 8x8GB High-Bandwidth Memory (HBM)
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
# imports
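import jax, requests
import PIL.Image  # `import PIL` alone does not guarantee PIL.Image is loaded
from IPython.display import display  # already available in Colab; explicit for clarity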
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionImg2ImgPipeline
from flax import jax_utils
# get model
pretrained_model_hf_user = "runwayml" #@param {type:"string"}
pretrained_model_name = "stable-diffusion-v1-5" #@param {type:"string"}
pretrained_model_ver = "flax" #@param {type:"string"} ["","bf16", "fp16", "flax"]
pretrained_model_name_or_path = pretrained_model_name + "/"
url = f"https://huggingface.co/{pretrained_model_hf_user}/{pretrained_model_name}"
if pretrained_model_ver != "":
    url = f"-b {pretrained_model_ver} " + url  # clone the requested branch, e.g. "flax"
if not os.path.exists(pretrained_model_name_or_path):
    !git lfs install
    !git clone {url}
# model setup - i normally build components separately but whatevs this works too
img2img, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(pretrained_model_name_or_path, dtype=jnp.bfloat16)
for key, val in img2img.components.items():
    print(key)
    if key in ("tokenizer", "scheduler", "feature_extractor"):
        continue  # no model weights to cast with to_bf16()
    params[key] = val.to_bf16(params[key])  # cast model weights to bfloat16
# get a sample image, in np format range 0 to 1, same as output of text2img (flax) pipe
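sample_url = "https://huggingface.co/nitrosocke/Arcane-Diffusion/resolve/main/magical_princess.png"
pil_image = PIL.Image.open(requests.get(sample_url, stream=True).raw).convert("RGB")  # drop any alpha channel
image = jnp.asarray(pil_image).transpose((2, 0, 1)) / 255  # HWC uint8 -> CHW float in [0, 1]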
# make a handy display fn:
def display_jnp_image(image):
    # scale [0, 1] floats to uint8 pixel values
    image = np.clip(np.array(image * 255), 0, 255).astype(np.uint8)
    if image.shape[0] == 3:  # CHW -> HWC for PIL
        image = np.transpose(image, axes=(1, 2, 0))
    image = PIL.Image.fromarray(image, mode="RGB")
    display(image)
# show base image:
display_jnp_image(image)
# use image for img2img without changing its range
prompt = "a painting of a lonely princess"
prompt_ids = img2img.tokenizer(prompt,padding="max_length",max_length=img2img.tokenizer.model_max_length,truncation=True,return_tensors="np").input_ids
rng = jax.random.PRNGKey(42)
here's the raw/original image: [image: original image]
here's the image in normal np format ([0, 1]) used as input to img2img, and a look at the whitewashed outputs:
### repeat this test for regular value of image and [-1,1] ranged one...
images = img2img( ### jax_batch is a batched conversion from a tensorflow data generator, reading image files and captions.
prompt_ids=jax_utils.replicate(prompt_ids),
image=jax_utils.replicate(jnp.expand_dims(image,axis=0)), # extra dim for batch size
params=jax_utils.replicate(params),
prng_seed=jax.random.split(rng,jax.device_count()),
strength=0.2, # default = 0.8, problem more observable the lower this is (i.e. more like orig)
num_inference_steps=50,
guidance_scale=jax_utils.replicate(jnp.asarray([5],dtype=jnp.float32)), # errors with args.dtype, i.e. jnp.bfloat16, so use float32
jit=True
).images
for img in images:
    print(img.min(), img.max())  ### proves output is range 0 to 1
    display_jnp_image(img.squeeze())
sample print of the range (min, max) and a whitewashed output image: 0.322266 1 [image: whitewashed output]
that print proves the img2img output has a value range between 0 and 1. now here's the input adjusted to the value range -1 to 1, which gives output with the expected appearance:
### results were whitewashed, now alter input image to be ranged [-1,1], run test, see results
image = image * 2 - 1 # changes range from [0,1] to [-1,1]
images = img2img( ### jax_batch is a batched conversion from a tensorflow data generator, reading image files and captions.
prompt_ids=jax_utils.replicate(prompt_ids),
image=jax_utils.replicate(jnp.expand_dims(image,axis=0)), # extra dim for batch size
params=jax_utils.replicate(params),
prng_seed=jax.random.split(rng,jax.device_count()),
strength=0.2, # default = 0.8, problem more observable the lower this is (i.e. more like orig)
num_inference_steps=50,
guidance_scale=jax_utils.replicate(jnp.asarray([5],dtype=jnp.float32)), # errors with args.dtype, i.e. jnp.bfloat16, so use float32
jit=True
).images
for img in images:
    print(img.min(), img.max())  ### proves output is range 0 to 1
    display_jnp_image(img.squeeze())
0 1 [image: output with the expected appearance]
note: the output value range is 0 to 1, even though the input is -1 to 1.
Thanks a lot for the code snippet. Also cc @pcuenca @patrickvonplaten: for Flax pipelines that accept images, maybe we should do the pre-processing inside the pipeline. Also related: #2061
I agree with @krahnikblis that the pipeline should produce outputs in the same range as its expected inputs; it's confusing and error-prone that inputs are in [-1, 1] and outputs are in [0, 1]. We also need to document it.
What range should we use? My vote would be for [-1, 1] for both input and output.
Note that this will be a breaking change, of course.
Interesting! In PyTorch the output is in [-1, 1], no? So this looks a bit like a bug to me.
Also in favor of forcing the output to be in [-1, 1] range.
Thanks a lot for the nice investigation @krahnikblis !
my vote would be to use the [0,1] range, since that's long been the output of the text2img pipe, and i think that's the standard way to represent np array images with a float dtype (couldn't quickly find better docs, but this page shows both the [0,255] and [0,1] methods). selfishly motivated too: i've already built all my stuff around np images being [0,1] :-P
my understanding of the [-1,1] format is that it represents a normalized distance from the mean value, that mean value being something i've seen hard-coded in a lot of places along with a pytorch F.normalize() call. i.e., the [-1,1] format isn't really a true image array, but rather a normalized, 0-centered representation of one. i always found it the most confusing to use (grain of salt here; pretty much everything pytorch confuses me)
I agree with @krahnikblis; I think it would be good to accept image arrays without assuming any pre-processing, and instead do the normalization inside the pipelines. Users could then just decode the PIL image into an array and pass it to the pipeline.
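A minimal sketch of what in-pipeline normalization could look like, assuming the caller passes either `np.asarray(pil_image)` (uint8, HWC, [0, 255]) or a float HWC array already in [0, 1]. `prepare_image` is a hypothetical name, not an existing diffusers function.

```python
import numpy as np

def prepare_image(image):
    """Hypothetical preprocessing: raw HWC image array -> (1, 3, H, W) float32 in [-1, 1]."""
    image = np.asarray(image)
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    else:
        image = image.astype(np.float32)          # assume already in [0, 1]
    image = image[..., :3]                        # drop an alpha channel if present
    image = image.transpose(2, 0, 1)[None]        # HWC -> (1, C, H, W)
    return image * 2.0 - 1.0                      # [0, 1] -> [-1, 1]

# a 4x4 RGBA image, all white and opaque except one black pixel
rgba = np.full((4, 4, 4), 255, dtype=np.uint8)
rgba[0, 0, :3] = 0
out = prepare_image(rgba)
print(out.shape, out.min(), out.max())
```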
Yeah, sorry, I thought PT outputs were in [-1, 1], but PyTorch also outputs images in [0, 1]:
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
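images = pipeline("example", output_type="np").images  # float array, shape (batch, H, W, 3)
print(images.min(), images.max())  # values fall within [0, 1]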
So fully agree with @krahnikblis and @patil-suraj here => All PT and Flax pipelines that accept images should accept images in the range [0, 1].
If I look here: https://github.com/huggingface/diffusers/blob/2f9a70aa852d84adaa17a824454bf6d28180b55a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L89 one can see that currently, when passing a PyTorch tensor to the img2img pipeline, the range [-1, 1] is expected even in PyTorch, which is wrong IMO. => We should deprecate inputs in the range [-1, 1] and then convert tensors from [0, 1] to [-1, 1] here: https://github.com/huggingface/diffusers/blob/2f9a70aa852d84adaa17a824454bf6d28180b55a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L89
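One possible shape for that deprecation path, as a hedged sketch: a heuristic (an assumption here, not the actual diffusers logic) treats any input containing negative values as legacy [-1, 1] data, warns, and passes it through; everything else is assumed to be [0, 1] and is mapped to [-1, 1]. A legacy [-1, 1] image with no negative pixels would be misclassified, which is one reason an explicit flag may be preferable. `to_model_range` is a hypothetical name.

```python
import warnings
import numpy as np

def to_model_range(image):
    """Hypothetical shim: accept [0, 1] inputs, keep legacy [-1, 1] inputs working."""
    image = np.asarray(image, dtype=np.float32)
    if image.min() < 0:
        warnings.warn(
            "Passing images in [-1, 1] is deprecated; pass images in [0, 1] instead.",
            FutureWarning,
        )
        return image              # already in the model's range
    return image * 2.0 - 1.0      # [0, 1] -> [-1, 1]

new_style = to_model_range(np.array([0.0, 0.5, 1.0]))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = to_model_range(np.array([-1.0, 0.0, 1.0]))
print(new_style, legacy, len(caught))
```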
@yiyixuxu this issue is very relevant to #2304 and should be solved by that PR.
Describe the bug
when running the img2img pipeline, the outputs are all too light, i.e. nothing is darker than middle gray. it's not very pronounced at a higher/default "strength" value, but for use cases akin to style transfer, where only a little noise is added to the original image, it's very apparent. in other words, the less the image is meant to change, the lighter the output appears.
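A back-of-the-envelope sketch of why low strength gives "nothing darker than middle gray", under two assumptions: at very low strength the decoded image stays close to the input values, and the pipeline maps the decoded result back with `x / 2 + 0.5` (clipped). A [0, 1] input that the model treats as [-1, 1] data then comes back compressed into [0.5, 1].

```python
import numpy as np

def postprocess(decoded):
    # assumed output mapping after decoding: [-1, 1] -> [0, 1]
    return np.clip(decoded / 2.0 + 0.5, 0.0, 1.0)

# pixel values of a [0, 1] image, from black to white
pixels = np.linspace(0.0, 1.0, 5)

# low-strength idealization: the decoded values reproduce the input values
wrong_input_out = postprocess(pixels)           # passed in without * 2 - 1
right_input_out = postprocess(pixels * 2 - 1)   # passed in as [-1, 1]

print(wrong_input_out.min(), wrong_input_out.max())  # black has become mid-gray
print(right_input_out.min(), right_input_out.max())
```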
this felt like ye olde torch copypasta problem, wherein pytorch image tensors like to be in the range -1 to 1, PIL uses 0 to 255, and numpy uses 0 to 1. or something, i'm not sure; i don't know why no one keeps to a single standard and we always swap channels and convert value ranges... anyway, i tested it, and indeed things work correctly when i alter my pixel_values inputs to be in the range -1 to 1 instead of 0 to 1, which is what numpy arrays use everywhere. however, the outputs of the pipeline are always in the range 0 to 1, i.e., it expects -1 to 1 inputs but gives back 0 to 1 outputs.
expected behavior is that it takes the same format in as it gives out. in the jax/numpy world, that should be the range 0 to 1.
Reproduction
i'm running this on Colab with TPU, so sharding/replication is done accordingly below...
swap the commented line for the image argument. my input array for pixel_values is in the range 0 to 1, so the extra *2-1 makes it -1 to 1.
Logs
No response
System Info
colab TPU, high-ram, latest versions of transformers and diffusers and flax and optax.