patil-suraj / vit-vqgan

JAX implementation of ViT-VQGAN
MIT License

Support LPIPS when using LAB colorspace. #18

Open pcuenca opened 2 years ago

pcuenca commented 2 years ago

Addresses #12.

I also double-checked that the inputs to LPIPS are images in the range [-1, 1]; everything looks correct in that regard.

borisdayma commented 2 years ago

Can you call the function to_rgb (or similar) and make it a noop when the format is already RGB?

borisdayma commented 2 years ago

Also, can you check whether it affects the prepare_dataset notebook and the wandb logging? You may need to use jax.device_get there.
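For the wandb logging point, here is a minimal sketch of what the host-side conversion might look like, assuming the logged images are channels-last RGB arrays in [-1, 1]. `to_wandb_images` is a hypothetical helper, not a function in this repo. `jax.device_get` copies the arrays from accelerator memory into host NumPy arrays before they are rescaled to uint8.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_wandb_images(images: jax.Array) -> np.ndarray:
    """Convert RGB images in [-1, 1] (possibly still on device) to uint8 NumPy arrays.

    Hypothetical helper; assumes channels-last RGB input.
    """
    # Copy the data from accelerator memory into a host NumPy array.
    host = np.asarray(jax.device_get(images))
    # Map [-1, 1] -> [0, 255], clipping to guard against small overshoots.
    host = np.clip((host + 1.0) * 127.5, 0.0, 255.0)
    return np.round(host).astype(np.uint8)
```

Each element of the result can then be passed to `wandb.Image`.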

pcuenca commented 2 years ago

> Can you call the function to_rgb (or similar) and make it a noop when the format is already RGB?

I didn't want to call it to_rgb because it would then look semantically similar to logits_to_image, which returns images in [0, 1], whereas this function works in [-1, 1]. But maybe that name is clearer anyway.

borisdayma commented 2 years ago

> > Can you call the function to_rgb (or similar) and make it a noop when the format is already RGB?
>
> I didn't want to call it to_rgb because then it looks semantically similar to logits_to_image, but images are returned in [0, 1] and this one works in [-1, 1]. But maybe it's clearer that way.

Maybe to_lpips? Pretty close to what you had then.
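A minimal sketch of the to_lpips idea, assuming the colorspace is known as a plain Python string (so it can be a static argument under `jax.jit`) and that the caller supplies a jit-compatible LAB-to-RGB converter returning RGB in [0, 1]. The name follows the suggestion above, but the signature and the `lab_to_rgb` parameter are hypothetical.

```python
from typing import Callable, Optional

import jax


def to_lpips(
    images: jax.Array,
    colorspace: str,
    lab_to_rgb: Optional[Callable[[jax.Array], jax.Array]] = None,
) -> jax.Array:
    """Return images as RGB in [-1, 1], the range LPIPS expects.

    Hypothetical sketch: `lab_to_rgb` is assumed to map the model's LAB
    representation to RGB in [0, 1] and to be traceable by jax.jit.
    """
    if colorspace == "rgb":
        # Already in the format LPIPS expects: no-op.
        return images
    if colorspace == "lab":
        if lab_to_rgb is None:
            raise ValueError("a jit-compatible lab_to_rgb function is required for LAB inputs")
        rgb01 = lab_to_rgb(images)
        # Rescale [0, 1] -> [-1, 1].
        return rgb01 * 2.0 - 1.0
    raise ValueError(f"unsupported colorspace: {colorspace!r}")
```

Because `colorspace` is a Python string, the branch is resolved at trace time. Wrapping it as `jax.jit(to_lpips, static_argnames=("colorspace", "lab_to_rgb"))` means the RGB path compiles to nothing extra. `lab_to_rgb` is marked static too because a function cannot be passed as a traced array.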

pcuenca commented 2 years ago

> Also can you check if it affects the prepare_dataset notebook and the wandb logging (may need to use a jax.device_get there).

Sure, good point. However, this branch was created from main, not stylegan. I'll check whether any changes are needed there and, if so, it may make sense to submit this as a PR against stylegan instead.

pcuenca commented 2 years ago

I addressed a couple of @borisdayma's comments, but it turns out we can't use tfio.experimental.color.lab_to_rgb in jitted code, because its implementation performs an implicit conversion to NumPy.
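The failure mode described above can be reproduced without tfio. This sketch does not use tfio's code. `convert_via_numpy` is a hypothetical stand-in for any function that calls `np.asarray` on its input. It works eagerly, but it raises `jax.errors.TracerArrayConversionError` once `jax.jit` passes it an abstract tracer instead of concrete values.

```python
import jax
import jax.numpy as jnp
import numpy as np


def convert_via_numpy(x):
    # Hypothetical stand-in for a library function that silently converts to NumPy.
    return jnp.asarray(np.asarray(x) * 2.0)


def convert_in_jax(x):
    # The same computation written with jax.numpy only, so jit can trace it.
    return x * 2.0


def fails_under_jit(fn, x) -> bool:
    """Return True if fn cannot be traced by jax.jit because it converts to NumPy."""
    try:
        jax.jit(fn)(x)
    except jax.errors.TracerArrayConversionError:
        return True
    return False
```

Calling `fails_under_jit(convert_via_numpy, jnp.ones(3))` reports a failure, while `convert_in_jax` traces without error. This is why a JAX-native conversion would be needed to keep the LPIPS path inside jitted code.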

I checked the color conversion functions in dm_pix (https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/color_conversion.py), but they don't have LAB support yet.

We could port TensorFlow's implementation to JAX, but maybe there are more important things to do first. What do you think, @borisdayma, @patil-suraj?
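If the port turns out to be worth it, the core conversion is small. Below is a minimal sketch, assuming unnormalized CIE LAB input (L in [0, 100]), the D65 white point with the 2-degree observer, and sRGB output in [0, 1]. It writes the standard CIE LAB and sRGB formulas in `jax.numpy` so the function can be jitted and differentiated. It is not a line-by-line port of TensorFlow's implementation and has not been checked against tfio numerically. Any extra scaling the model applies to its LAB outputs would have to be undone before calling it.

```python
import jax
import jax.numpy as jnp

# D65 reference white (2-degree observer), XYZ with Y normalized to 1.
_D65_WHITE = jnp.array([0.95047, 1.0, 1.08883])

# Matrix mapping CIE XYZ (D65) to linear sRGB, for column vectors.
_XYZ_TO_LINEAR_RGB = jnp.array([
    [3.2404542, -1.5371385, -0.4985314],
    [-0.9692660, 1.8760108, 0.0415560],
    [0.0556434, -0.2040259, 1.0572252],
])

_DELTA = 6.0 / 29.0


def _lab_f_inv(t):
    # Inverse of the CIE LAB companding function f.
    return jnp.where(t > _DELTA, t ** 3, 3.0 * _DELTA ** 2 * (t - 4.0 / 29.0))


def _linear_to_srgb(c):
    # sRGB transfer function. The maximum() keeps the power away from negative
    # inputs so neither the forward pass nor the gradients produce NaNs.
    safe = jnp.maximum(c, 0.0031308)
    return jnp.where(c <= 0.0031308, 12.92 * c, 1.055 * safe ** (1.0 / 2.4) - 0.055)


@jax.jit
def lab_to_rgb(lab: jax.Array) -> jax.Array:
    """Convert CIE LAB (L in [0, 100], a/b roughly in [-128, 127]) to sRGB in [0, 1].

    Works on any array whose last axis has size 3, e.g. (batch, height, width, 3).
    """
    l, a, b = lab[..., 0], lab[..., 1], lab[..., 2]
    fy = (l + 16.0) / 116.0
    fx = fy + a / 500.0
    fz = fy - b / 200.0
    xyz = jnp.stack([_lab_f_inv(fx), _lab_f_inv(fy), _lab_f_inv(fz)], axis=-1) * _D65_WHITE
    # Row-vector form of rgb = M @ xyz.
    rgb_linear = xyz @ _XYZ_TO_LINEAR_RGB.T
    return jnp.clip(_linear_to_srgb(rgb_linear), 0.0, 1.0)
```

For LPIPS, the result would still need to be mapped to [-1, 1], e.g. `lab_to_rgb(lab) * 2.0 - 1.0`.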