NVlabs / stylegan3

Official PyTorch implementation of StyleGAN3

Implementing training in HSV color space? #210

Open Kaoru8 opened 1 year ago

Kaoru8 commented 1 year ago

I've been thinking for a while now about the fact that most image-based models operate in the RGB color space. While that makes sense from a practicality standpoint, since it's the most common format and requires little or no preprocessing/conversion, RGB has properties that make it a less-than-ideal choice for training a neural network. Interpolation through RGB space is "unnatural" in terms of human perception: points that are mathematically close in the color space can be quite different in terms of perceived similarity. Human perception obviously isn't the same as a neural network's, but it stands to reason that a smoother, more perceptually natural color interpolation curve would be better suited to gradient descent and could lead to faster training/convergence - especially on datasets of predominantly human-created images (artwork, design, fashion, etc.), where human color perception strongly shapes the color distribution of the training images.

With all of that in mind, I set out to convert StyleGAN3 to operate in the HSV color space. It seemed like a trivial task at first: after loading the RGB images as usual, just convert to HSV and normalize to the [-1, 1] range as you would with RGB images. The network is then effectively training in HSV, and when synthesizing we just need to de-normalize and convert HSV -> RGB. Of course, it ended up being not quite that simple...
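Something like this minimal sketch is what I had in mind (using skimage.color purely for illustration; the helper names are mine, not anything from dataset.py or training_loop.py):

```python
import numpy as np
from skimage.color import rgb2hsv, hsv2rgb  # illustrative choice of conversion library

def rgb_uint8_to_hsv_norm(img):
    """uint8 HWC RGB image -> float HSV normalized to [-1, 1] for the network."""
    hsv = rgb2hsv(img)        # H, S, V each in [0, 1]
    return hsv * 2.0 - 1.0    # same [-1, 1] range the network normally sees for RGB

def hsv_norm_to_rgb_uint8(hsv_norm):
    """Network output in [-1, 1] (interpreted as HSV) -> uint8 RGB for saving/display."""
    hsv = np.clip((hsv_norm + 1.0) / 2.0, 0.0, 1.0)  # back to [0, 1]
    rgb = hsv2rgb(hsv)                                # float RGB in [0, 1]
    return (rgb * 255.0).round().astype(np.uint8)
```

The first helper would run wherever the dataset images are loaded/normalized, and the second wherever generated images are de-normalized for saving - but exactly where to hook those in is part of what I'm unsure about.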

I'm pretty confident about making any necessary conversion/normalization changes to training/dataset.py and training/training_loop.py, but things get more complicated from there. Since different color spaces were never a concern, RGB is assumed and hardcoded in quite a few places, including the core network architecture(s), with toRGB layers, fromRGB functions, etc. I definitely don't have the confidence to change the architecture and know that what I'm doing actually makes sense for what I'm trying to accomplish, so I'm hoping someone more knowledgeable can chime in...
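As I understand it (and this is my assumption, not a claim about the exact StyleGAN3 code), the "toRGB" layers are essentially learned projections from the feature maps down to img_channels output channels, schematically something like:

```python
import torch.nn as nn

class ToOutputColor(nn.Module):
    """Schematic stand-in for a 'toRGB' layer: a 1x1 conv from feature channels
    to img_channels. Nothing here is specific to RGB - if the training data is
    HSV, the three output channels would simply be learned as H, S, V instead."""
    def __init__(self, in_channels, img_channels=3):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, img_channels, kernel_size=1)

    def forward(self, x):
        return self.proj(x)
```

If that reading is correct, the architecture itself wouldn't care which color space the three channels represent, and only the input/output ranges would need to stay consistent - but that's exactly the part I'd like someone to confirm.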

How difficult would this actually be to implement? Would it require architecture changes, or can I assume that everything will work as expected provided I normalize the network inputs to the same range?