Open Miladiouss opened 6 years ago
Thanks for the issue!
I think we could provide a histogram transformation functionality in torchvision. Maybe one possibility would be to allow the user to pass in directly the target histogram, instead of passing the image, and provide a simple functionality to compute the histogram of an image.
Also, apart from np.interp, all the other functions have torch equivalents, so maybe we could make it use torch functions whenever possible?
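As a rough illustration of the two pieces mentioned here, the sketch below computes a histogram with torch.histc and stands in for np.interp using torch.searchsorted. The helper names image_histogram and interp, and the default range of [0, 256), are assumptions made for this example; they are not existing torchvision functions.

```python
import torch


def image_histogram(img: torch.Tensor, bins: int = 256,
                    lo: float = 0.0, hi: float = 256.0):
    """Return (bin_edges, counts) over all values in img (hypothetical helper)."""
    counts = torch.histc(img.float().flatten(), bins=bins, min=lo, max=hi)
    edges = torch.linspace(lo, hi, bins + 1)
    return edges, counts


def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """Piecewise-linear interpolation, a stand-in for np.interp.

    Assumes xp is 1-D and strictly increasing. Queries outside
    [xp[0], xp[-1]] are clamped to the end values of fp.
    """
    x = x.to(xp.dtype).contiguous()
    # Index of the right-hand knot for every query point.
    idx = torch.searchsorted(xp, x).clamp(1, xp.numel() - 1)
    x0, x1 = xp[idx - 1], xp[idx]
    f0, f1 = fp[idx - 1], fp[idx]
    t = ((x - x0) / (x1 - x0)).clamp(0.0, 1.0)
    return f0 + t * (f1 - f0)
```

With these two helpers, a target histogram could be computed once and passed around instead of the template image itself.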
Also, could you send a PR?
Thank you @fmassa for your reply. The function above was written outside the context of PyTorch. That's why it's all NumPy. If you are going to integrate my code, I need some time (~ 1 month) to write it properly. As you said, I will use torch functions whenever possible and pass a target histogram instead. That's easy.
But the user should be able to
My biggest worry is applying a histogram transform on floats. I need to see how other people do this (if you know any references, please send them my way). Otherwise, the way I would do it is to divide the input image/tensor into $b$ ranges, where $b$ equals the number of bins of the provided histogram. Then we need to truncate values by choosing min/max values. These limits can be found by finding the location where a small portion $p$ of the data falls below/above those values. By implementing this we could address image processing for non-standard formats such as RAW in photography and FITS in astronomy, as well as sound files.
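A minimal sketch of the truncate-then-bin idea for float data, assuming the limits are taken as the $p$ and $1 - p$ quantiles via torch.quantile. The function name clip_and_bin and the default p = 0.01 are illustrative choices, not part of the proposal.

```python
import torch


def clip_and_bin(x: torch.Tensor, b: int, p: float = 0.01):
    """Clip x to its [p, 1 - p] quantile range and assign each value to one of b bins.

    Returns (bin_index, lo, hi). bin_index has the shape of x, with values
    in [0, b - 1]. Hypothetical helper for illustration only.
    """
    flat = x.flatten().float()
    lo = torch.quantile(flat, p)        # a fraction p of the data lies below lo
    hi = torch.quantile(flat, 1.0 - p)  # a fraction p of the data lies above hi
    clipped = x.float().clamp(lo.item(), hi.item())
    # Guard against a constant input, where hi == lo.
    width = (hi - lo).clamp_min(torch.finfo(torch.float32).tiny)
    idx = ((clipped - lo) / width * b).long().clamp(0, b - 1)
    return idx, lo, hi
```

The returned indices can then be treated like the discrete pixel values of a uint8 image with b levels.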
By putting more effort into it, we could provide template histograms eliminating the need for gamma correction, contrast adjustments, and mean and standard deviation shifting.
Do you think I should send the pull request before addressing the big issues above or after (I'm still new to the GitHub community)?
About points 1-2: what about the following: if the user provides a 1d histogram, it performs the same equalization for all channels (as if it was broadcast for all channels), and if it is a 2d histogram, then each channel uses one of the provided histograms. One of the limitations of this approach is that the number of elements of all histograms should be the same, but this is usually fine for uint8 images.
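The 1d/2d convention described above could be normalized up front, so the rest of the transform only ever sees one histogram per channel. This is a sketch under that assumption; expand_hist_per_channel is a made-up name.

```python
import torch


def expand_hist_per_channel(hist: torch.Tensor, num_channels: int) -> torch.Tensor:
    """Return a (num_channels, bins) histogram from a 1-D or 2-D input."""
    if hist.dim() == 1:
        # Same target for every channel; expand creates a view, not a copy.
        return hist.unsqueeze(0).expand(num_channels, -1)
    if hist.dim() == 2:
        if hist.shape[0] != num_channels:
            raise ValueError(
                f"expected {num_channels} histograms, got {hist.shape[0]}"
            )
        return hist
    raise ValueError("hist must be 1-D (shared) or 2-D (one row per channel)")
```

Because the 2d case is a single tensor, every channel necessarily has the same number of bins, which is the limitation noted above.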
If we pass the bins of the histogram as well, we don't need to worry if the image is floating point or not, so we would pass not only the counts but also what value each count accounts for. I think this would solve your comments, right?
I don't have any experience with RAW or FITS images, but please feel free to send a PR with what we have discussed. Also, raising any issues you might see with it is definitely valuable!
I like that! Then I'll start writing the function hist_transform(input_tensor, hist_bin, hist_count), all torch tensors and inform you when done.
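For readers following along, here is one way such a function might look. It is only a sketch with the signature named above, not the code that was eventually written: it assumes a single-channel input, hist_bin sorted in ascending order, and a step-function inverse CDF (no interpolation between bin values).

```python
import torch


def hist_transform(input_tensor: torch.Tensor,
                   hist_bin: torch.Tensor,
                   hist_count: torch.Tensor) -> torch.Tensor:
    """Remap input_tensor so its histogram follows (hist_bin, hist_count).

    hist_bin holds the value each count refers to (ascending), and
    hist_count holds the counts, so float images need no special handling.
    """
    flat = input_tensor.flatten()
    n = flat.numel()
    # Rank of every pixel within the image; ties are broken by sort order.
    order = torch.argsort(flat)
    ranks = torch.empty_like(order)
    ranks[order] = torch.arange(n, device=flat.device)
    q = (ranks.to(torch.float64) + 0.5) / n  # empirical quantile in (0, 1)
    # Target CDF, normalized so the last entry is exactly 1.
    cdf = torch.cumsum(hist_count.to(torch.float64), dim=0)
    cdf = cdf / cdf[-1]
    # Smallest bin whose cumulative mass reaches the pixel's quantile.
    idx = torch.searchsorted(cdf, q).clamp(max=hist_bin.numel() - 1)
    return hist_bin[idx].reshape(input_tensor.shape)
```

For example, a 2x2 image [[0.1, 0.9], [0.5, 0.3]] with hist_bin = [10, 20] and hist_count = [1, 1] maps the two darkest pixels to 10 and the two brightest to 20.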
Hello @fmassa again. So I finished polishing and testing the code as we discussed. It is very inconvenient to switch back and forth between torch.tensor and numpy.ndarray, so I decided to do everything in numpy.
I am very new to git and GitHub and I don't know how to properly do a pull request. If you can help me with that, I'll appreciate it. See this gist for the code which includes a module test. So, the part that includes the major code is above # For tests and demonstration.
Hey 😃
I was browsing through the vision issues and found that one, turns out I actually did some work on histogram specification some time ago. Something like that :
I wrote it as a cuda module as I was running the transform in an optimisation loop and needed it to be fast. The code is available over here if that can be useful : https://github.com/pierre-wilmot/NeuralTextureSynthesis/ Happy to help cleaning it up if you think it's worth adding to the vision repo.
@gheaeckkseqrz Thanks for the proposal! I think it could be a nice addition to have efficient GPU-accelerated transforms, but first we need to have a reference implementation and I have to find the time to review the PR in #796
@gheaeckkseqrz's proposal would also really help with style transfer and potentially GAN related tasks as well. Support for histogram matching on GPU with tensors can be extremely useful for style transfer. https://github.com/ProGamerGov/neural-style-pt/issues/46#issuecomment-563005587
@Miladiouss So, I created this function that essentially matches the histogram of one image to another image, and it should hopefully help individuals with use cases like astronomy & neural style transfer.
I wrote the code for a different PyTorch project (pytorch/captum), but Torchvision is free to use it as well! @fmassa
from typing import Tuple

import torch


def color_transfer(
    input: torch.Tensor,
    source: torch.Tensor,
    mode: str = "pca",
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Transfer the colors from one image tensor to another, so that the target image's
    histogram matches the source image's histogram. Applications for image histogram
    matching include neural style transfer and astronomy.

    The source image is not required to have the same height and width as the target
    image. Batch and channel dimensions are required to be the same for both inputs.

    Gatys, et al., "Controlling Perceptual Factors in Neural Style Transfer", arXiv, 2017.
    https://arxiv.org/abs/1611.07865

    Args:
        input (torch.Tensor): The NCHW or CHW target image that receives the colors
            of the source image.
        source (torch.Tensor): The NCHW or CHW image whose colors are transferred to
            the input image.
        mode (str): The color transfer mode to use. One of 'pca', 'cholesky', or 'sym'.
            Default: "pca"
        eps (float): The desired epsilon value to use.
            Default: 1e-5

    Returns:
        matched_image (torch.tensor): The NCHW input image with the colors of source
            image. Outputs should ideally be clamped to the desired value range to
            avoid artifacts.
    """
    assert input.dim() == 3 or input.dim() == 4
    assert source.dim() == 3 or source.dim() == 4
    input = input.unsqueeze(0) if input.dim() == 3 else input
    source = source.unsqueeze(0) if source.dim() == 3 else source
    assert input.shape[:2] == source.shape[:2]

    # Handle older versions of PyTorch. Feature detection is used instead of
    # comparing version strings, because "1.10.0" >= "1.9.0" is False for strings.
    linalg = getattr(torch, "linalg", None)
    torch_cholesky = (
        torch.linalg.cholesky if hasattr(linalg, "cholesky") else torch.cholesky
    )

    def torch_symeig_eigh(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        torch.symeig() was deprecated in favor of torch.linalg.eigh()
        """
        if hasattr(linalg, "eigh"):
            L, V = torch.linalg.eigh(x, UPLO="U")
        else:
            L, V = torch.symeig(x, eigenvectors=True, upper=True)
        return L, V

    def get_mean_vec_and_cov(
        x_input: torch.Tensor, eps: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert input images into a vector, subtract the mean, and calculate the
        covariance matrix of colors.
        """
        x_mean = x_input.mean(3).mean(2)[:, :, None, None]

        # Subtract the color mean and convert to a vector
        B, C = x_input.shape[:2]
        x_vec = (x_input - x_mean).reshape(B, C, -1)

        # Calculate covariance matrix
        x_cov = torch.bmm(x_vec, x_vec.permute(0, 2, 1)) / x_vec.shape[2]

        # This line is only important if you get artifacts in the output image
        x_cov = x_cov + (eps * torch.eye(C, device=x_input.device)[None, :])
        return x_mean, x_vec, x_cov

    def pca(x: torch.Tensor) -> torch.Tensor:
        """Perform principal component analysis"""
        eigenvalues, eigenvectors = torch_symeig_eigh(x)
        e = torch.sqrt(torch.diag_embed(eigenvalues.reshape(eigenvalues.size(0), -1)))
        # Remove any NaN values if they occur
        if torch.isnan(e).any():
            e = torch.where(torch.isnan(e), torch.zeros_like(e), e)
        return torch.bmm(torch.bmm(eigenvectors, e), eigenvectors.permute(0, 2, 1))

    # Collect & calculate required values
    _, input_vec, input_cov = get_mean_vec_and_cov(input, eps)
    source_mean, _, source_cov = get_mean_vec_and_cov(source, eps)

    # Calculate new cov matrix for input
    if mode == "pca":
        new_cov = torch.bmm(pca(source_cov), torch.inverse(pca(input_cov)))
    elif mode == "cholesky":
        new_cov = torch.bmm(
            torch_cholesky(source_cov), torch.inverse(torch_cholesky(input_cov))
        )
    elif mode == "sym":
        p = pca(input_cov)
        pca_out = pca(torch.bmm(torch.bmm(p, source_cov), p))
        new_cov = torch.bmm(torch.bmm(torch.inverse(p), pca_out), torch.inverse(p))
    else:
        raise ValueError(
            "mode has to be one of 'pca', 'cholesky', or 'sym'."
            + " Received '{}'.".format(mode)
        )

    # Multiply input vector by new cov matrix
    new_vec = torch.bmm(new_cov, input_vec)

    # Reshape output vector back to input's shape &
    # add the source mean to our output vector
    return new_vec.reshape(input.shape) + source_mean


# Example for standard PyTorch images with value ranges of [0-1]
matched_image = color_transfer(target_image, source_image).clamp(0, 1)
The inner functions can be eliminated easily for TorchScript / JIT compatibility, and it's fully autograd compatible.
@ProGamerGov did this ever get implemented into any library?
It is often useful (especially in the field of astronomy) to transform the histogram of images. I would like to suggest an image histogram transformation function (under torchvision.transforms) that transforms the histogram of an image to match that of a template image as closely as possible. For instance, consider the following function:
The function above is not optimal since it has to recalculate the template image information on every call. It is not discretized for float-type images, and it only works for highly discretized images such as PNG (0-255 bins). It also performs poorly when the number of distinct pixel values is too low, which might be fixed by adding small noise.
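The "small noise" idea can be sketched as simple dithering applied before matching. This assumes integer-valued pixel levels spaced one unit apart; the function name dither and the uniform noise in [-0.5, 0.5) are illustrative choices, and whether this actually improves the match would need to be checked on real data.

```python
import torch


def dither(img: torch.Tensor, scale: float = 1.0, generator=None) -> torch.Tensor:
    """Add uniform noise in [-scale / 2, scale / 2) to spread out repeated values.

    Hypothetical helper: with scale=1.0 and integer pixel levels, each value
    stays within half a level of where it started.
    """
    noise = torch.rand(img.shape, generator=generator, device=img.device) - 0.5
    return img.float() + scale * noise
```

Passing a seeded torch.Generator makes the noise reproducible between runs.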