pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[feature request] Image Histogram Transformation #598

Open Miladiouss opened 6 years ago

Miladiouss commented 6 years ago

It is often useful (especially in the field of astronomy) to transform the histogram of images. I would like to suggest an image histogram transformation function (under torchvision.transforms) that transforms the histogram of an image to match that of a template image as closely as possible. For instance, consider the following function:

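import numpy as np

def match_histogram(source, template):
    """Remap the pixel values of source so its histogram matches that of template."""
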
    source   = np.asanyarray(source)
    template = np.asanyarray(template)
    oldshape = source.shape
    source   = source.ravel()
    template = template.ravel()

    # get the set of unique pixel values and their corresponding indices and
    # counts
    s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)

    # take the cumsum of the counts and normalize by the number of pixels to
    # get the empirical cumulative distribution functions for the source and
    # template images (maps pixel value --> quantile)
    s_quantiles  = np.cumsum(s_counts).astype(np.float32)
    s_quantiles /= s_quantiles[-1]
    t_quantiles  = np.cumsum(t_counts).astype(np.float32)
    t_quantiles /= t_quantiles[-1]

    # interpolate linearly to find the pixel values in the template image
    # that correspond most closely to the quantiles in the source image
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)

The function above is not optimal, since it recomputes the template image's statistics on every call. It does not bin float-type images, so it only works well for coarsely quantized images such as 8-bit PNGs (256 levels, 0-255). It also performs poorly when the number of distinct pixel values is too low, which might be fixed by adding a small amount of noise.
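To illustrate the first limitation, here is a minimal sketch that splits the function in two so the template's empirical CDF is computed once and reused across many source images. The names template_cdf and match_to_cdf are hypothetical and are not part of torchvision or of the proposal above.

```python
import numpy as np

def template_cdf(template):
    """Compute the template's unique values and empirical CDF once."""
    t_values, t_counts = np.unique(np.asanyarray(template).ravel(),
                                   return_counts=True)
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]
    return t_values, t_quantiles

def match_to_cdf(source, t_values, t_quantiles):
    """Match source to a precomputed template CDF (hypothetical helper)."""
    source = np.asanyarray(source)
    s_values, bin_idx, s_counts = np.unique(source.ravel(),
                                            return_inverse=True,
                                            return_counts=True)
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    # Map each source quantile to the template value at the same quantile
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
    return interp_t_values[bin_idx].reshape(source.shape)
```

Usage would be cdf = template_cdf(template) once, followed by match_to_cdf(image, *cdf) for each image in a dataset.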

fmassa commented 6 years ago

Thanks for the issue!

I think we could provide histogram transformation functionality in torchvision. One possibility would be to let the user pass in the target histogram directly, instead of passing the template image, and to provide a simple function that computes the histogram of an image.

Also, apart from np.interp, all the other functions have torch equivalents, so maybe we could make it use torch functions whenever possible?

Also, could you send a PR?
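On the point about torch equivalents: a rough sketch of how the NumPy function might be ported, assuming torch.unique, torch.cumsum and torch.searchsorted are available (PyTorch 1.6 or later). Since np.interp has no direct counterpart, the helper interp1d below is a hypothetical stand-in written for this example, not an existing torch function.

```python
import torch

def interp1d(x, xp, fp):
    """Piecewise-linear interpolation; xp must be strictly increasing."""
    if xp.numel() == 1:
        # A single sample point: every query maps to the same value
        return fp.expand_as(x).clone()
    # Index of the right-hand neighbour of each query, kept in [1, n - 1]
    idx = torch.searchsorted(xp, x).clamp(1, xp.numel() - 1)
    x0, x1 = xp[idx - 1], xp[idx]
    f0, f1 = fp[idx - 1], fp[idx]
    # Clamping the weight reproduces np.interp's behaviour outside [xp[0], xp[-1]]
    w = ((x - x0) / (x1 - x0)).clamp(0, 1)
    return f0 + w * (f1 - f0)

def match_histogram_torch(source, template):
    """Torch port of match_histogram; returns a float64 tensor."""
    s_values, bin_idx, s_counts = torch.unique(
        source.flatten(), return_inverse=True, return_counts=True)
    t_values, t_counts = torch.unique(template.flatten(), return_counts=True)
    s_q = torch.cumsum(s_counts, dim=0).to(torch.float64)
    s_q = s_q / s_q[-1]
    t_q = torch.cumsum(t_counts, dim=0).to(torch.float64)
    t_q = t_q / t_q[-1]
    matched = interp1d(s_q, t_q, t_values.to(torch.float64))
    return matched[bin_idx].reshape(source.shape)
```

Because torch.unique sorts its output, this version is not differentiable with respect to the pixel values in any useful way; it is meant as a data transform only.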

Miladiouss commented 6 years ago

Thank you @fmassa for your reply. The function above was written outside the context of PyTorch. That's why it's all NumPy. If you are going to integrate my code, I need some time (~ 1 month) to write it properly. As you said, I will use torch functions whenever possible and pass a target histogram instead. That's easy. But the user should be able to

My biggest worry is applying a histogram transform to floats. I need to see how other people do this (if you know of any references, please send them my way). Otherwise, the way I would do it is to divide the input image/tensor into $b$ ranges, where $b$ equals the number of bins of the provided histogram. Then we need to truncate values by choosing min/max limits. These limits can be found by locating the values below/above which a small portion $p$ of the data falls. By implementing this we could address image processing for non-standard formats such as RAW in photography and FITS in astronomy, as well as sound files.
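The truncate-then-bin idea could be sketched roughly as follows in NumPy. The function name clipped_histogram and the defaults for n_bins and p are assumptions made for this example only.

```python
import numpy as np

def clipped_histogram(image, n_bins=256, p=0.001):
    """Histogram of a float image after clipping the lowest and highest
    fraction p of the data (hypothetical helper)."""
    image = np.asanyarray(image, dtype=np.float64)
    # Limits below/above which a fraction p of the data falls
    lo, hi = np.quantile(image, [p, 1.0 - p])
    clipped = np.clip(image, lo, hi)
    # Divide [lo, hi] into n_bins equal ranges and count the pixels in each
    counts, edges = np.histogram(clipped, bins=n_bins, range=(lo, hi))
    return counts, edges
```

Clipping before binning keeps a few extreme pixels (hot pixels, saturated stars) from stretching the bin range so far that almost all of the data lands in one bin.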

By putting more effort into it, we could provide template histograms eliminating the need for gamma correction, contrast adjustments, and mean and standard deviation shifting.

Do you think I should send the pull request before addressing the big issues above or after (I'm still new to the GitHub community)?

fmassa commented 6 years ago

About points 1-2: what about the following: if the user provides a 1d histogram, it performs the same equalization for all channels (as if it was broadcast for all channels), and if it is a 2d histogram, then each channel uses one of the provided histograms. One of the limitations of this approach is that the number of elements of all histograms should be the same, but this is usually fine for uint8 images.

If we pass the bins of the histogram as well, we don't need to worry if the image is floating point or not, so we would pass not only the counts but also what value each count accounts for. I think this would solve your comments, right?

I don't have any experience with RAW or FITS images, but please feel free to send a PR with what we have discussed. Also, raising any issues you might see with it is definitely valuable!

Miladiouss commented 6 years ago

I like that! Then I'll start writing the function hist_transform(input_tensor, hist_bin, hist_count), all torch tensors and inform you when done.
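For reference, a rough NumPy sketch of what such an interface might look like, following the discussion above: a 1d histogram is broadcast to all channels, a 2d histogram supplies one row per channel, and the bin values are passed alongside the counts so float images need no special handling. Only the signature comes from this thread; the body is an assumption made for illustration and is not the code from the gist or the PR.

```python
import numpy as np

def _match_channel(channel, bin_values, bin_counts):
    """Match a single 2d channel to a histogram given as (values, counts)."""
    values, inverse, counts = np.unique(channel.ravel(),
                                        return_inverse=True,
                                        return_counts=True)
    s_q = np.cumsum(counts, dtype=np.float64)
    s_q /= s_q[-1]
    t_q = np.cumsum(bin_counts, dtype=np.float64)
    t_q /= t_q[-1]
    # Empty bins produce ties in t_q; np.interp does not check for them
    return np.interp(s_q, t_q, bin_values)[inverse].reshape(channel.shape)

def hist_transform(input_tensor, hist_bin, hist_count):
    """input_tensor: (C, H, W); hist_bin, hist_count: shape (B,) or (C, B)."""
    image = np.asanyarray(input_tensor)
    n_channels = image.shape[0]
    hist_bin = np.atleast_2d(hist_bin)
    hist_count = np.atleast_2d(hist_count)
    # A 1d histogram is reused for every channel
    hist_bin = np.broadcast_to(hist_bin, (n_channels, hist_bin.shape[1]))
    hist_count = np.broadcast_to(hist_count, (n_channels, hist_count.shape[1]))
    return np.stack([
        _match_channel(image[c], hist_bin[c], hist_count[c])
        for c in range(n_channels)
    ])
```

As noted above, this requires every channel's histogram to have the same number of bins.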

Miladiouss commented 6 years ago

Hello again @fmassa. I have finished polishing and testing the code as we discussed. It is very inconvenient to switch back and forth between torch.tensor and numpy.ndarray, so I decided to do everything in NumPy. I am very new to git and GitHub and I don't know how to properly make a pull request. If you can help me with that, I'd appreciate it. See this gist for the code, which includes a module test. The main code is the part above the line # For tests and demonstration.

Miladiouss commented 5 years ago

I have made a pull request here. Also, see this gist for tests and proof of concept.

gheaeckkseqrz commented 5 years ago

Hey 😃

I was browsing through the vision issues and found this one. It turns out I actually did some work on histogram specification some time ago. Something like this:

[image: example of histogram specification]

I wrote it as a CUDA module because I was running the transform in an optimisation loop and needed it to be fast. The code is available here in case it is useful: https://github.com/pierre-wilmot/NeuralTextureSynthesis/ Happy to help clean it up if you think it's worth adding to the vision repo.

fmassa commented 5 years ago

@gheaeckkseqrz Thanks for the proposal! I think it could be a nice addition to have efficient GPU-accelerated transforms, but first we need to have a reference implementation and I have to find the time to review the PR in #796

ProGamerGov commented 4 years ago

@gheaeckkseqrz's proposal would also really help with style transfer and potentially GAN related tasks as well. Support for histogram matching on GPU with tensors can be extremely useful for style transfer. https://github.com/ProGamerGov/neural-style-pt/issues/46#issuecomment-563005587

ProGamerGov commented 2 years ago

@Miladiouss So, I created this function that essentially matches the histogram of one image to another image, and it should hopefully help individuals with use cases like astronomy & neural style transfer.

I wrote the code for a different PyTorch project (pytorch/captum), but Torchvision is free to use it as well! @fmassa

from typing import Tuple

import torch

def color_transfer(
    input: torch.Tensor,
    source: torch.Tensor,
    mode: str = "pca",
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Transfer the colors from one image tensor to another by matching the color mean
    and covariance of the input image to those of the source image, so that the
    input image's histogram approximately matches the source image's histogram.
    Applications for image histogram matching include neural style transfer and
    astronomy.

    The source image is not required to have the same height and width as the target
    image. Batch and channel dimensions are required to be the same for both inputs.

    Gatys, et al., "Controlling Perceptual Factors in Neural Style Transfer", arXiv, 2017.
    https://arxiv.org/abs/1611.07865

    Args:

        input (torch.Tensor): The NCHW or CHW image to transfer the colors of the
            source image to.
        source (torch.Tensor): The NCHW or CHW image whose colors are transferred
            to the input image.
        mode (str): The color transfer mode to use. One of 'pca', 'cholesky', or 'sym'.
            Default: "pca"
        eps (float): The desired epsilon value to use.
            Default: 1e-5

    Returns:
        matched_image (torch.Tensor): The NCHW input image with the colors of the
            source image. Outputs should ideally be clamped to the desired value
            range to avoid artifacts.
    """

    assert input.dim() == 3 or input.dim() == 4
    assert source.dim() == 3 or source.dim() == 4
    input = input.unsqueeze(0) if input.dim() == 3 else input
    source = source.unsqueeze(0) if source.dim() == 3 else source
    assert input.shape[:2] == source.shape[:2]

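    # Handle older versions of PyTorch by checking for the newer API directly;
    # comparing version strings gives the wrong answer for releases such as "1.10.0"
    has_linalg_cholesky = hasattr(torch, "linalg") and hasattr(
        torch.linalg, "cholesky"
    )
    torch_cholesky = torch.linalg.cholesky if has_linalg_cholesky else torch.cholesky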

    def torch_symeig_eigh(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        torch.symeig() was deprecated in favor of torch.linalg.eigh()
        """
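        # Feature check instead of a string comparison, which fails for "1.10.0"
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):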
            L, V = torch.linalg.eigh(x, UPLO="U")
        else:
            L, V = torch.symeig(x, eigenvectors=True, upper=True)
        return L, V

    def get_mean_vec_and_cov(
        x_input: torch.Tensor, eps: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert input images into a vector, subtract the mean, and calculate the
        covariance matrix of colors.
        """
        x_mean = x_input.mean(3).mean(2)[:, :, None, None]

        # Subtract the color mean and convert to a vector
        B, C = x_input.shape[:2]
        x_vec = (x_input - x_mean).reshape(B, C, -1)

        # Calculate covariance matrix
        x_cov = torch.bmm(x_vec, x_vec.permute(0, 2, 1)) / x_vec.shape[2]

        # This line is only important if you get artifacts in the output image
        x_cov = x_cov + (eps * torch.eye(C, device=x_input.device)[None, :])
        return x_mean, x_vec, x_cov

    def pca(x: torch.Tensor) -> torch.Tensor:
        """Perform principal component analysis"""
        eigenvalues, eigenvectors = torch_symeig_eigh(x)
        e = torch.sqrt(torch.diag_embed(eigenvalues.reshape(eigenvalues.size(0), -1)))
        # Remove any NaN values if they occur
        if torch.isnan(e).any():
            e = torch.where(torch.isnan(e), torch.zeros_like(e), e)
        return torch.bmm(torch.bmm(eigenvectors, e), eigenvectors.permute(0, 2, 1))

    # Collect & calculate required values
    _, input_vec, input_cov = get_mean_vec_and_cov(input, eps)
    source_mean, _, source_cov = get_mean_vec_and_cov(source, eps)

    # Calculate new cov matrix for input
    if mode == "pca":
        new_cov = torch.bmm(pca(source_cov), torch.inverse(pca(input_cov)))
    elif mode == "cholesky":
        new_cov = torch.bmm(
            torch_cholesky(source_cov), torch.inverse(torch_cholesky(input_cov))
        )
    elif mode == "sym":
        p = pca(input_cov)
        pca_out = pca(torch.bmm(torch.bmm(p, source_cov), p))
        new_cov = torch.bmm(torch.bmm(torch.inverse(p), pca_out), torch.inverse(p))
    else:
        raise ValueError(
            "mode has to be one of 'pca', 'cholesky', or 'sym'."
            + " Received '{}'.".format(mode)
        )

    # Multiply input vector by new cov matrix
    new_vec = torch.bmm(new_cov, input_vec)

    # Reshape output vector back to input's shape &
    # add the source mean to our output vector
    return new_vec.reshape(input.shape) + source_mean

# Example for standard PyTorch images with value ranges of [0-1]
matched_image = color_transfer(target_image, source_image).clamp(0, 1)

The inner functions can be eliminated easily for TorchScript / JIT compatibility, and it's fully autograd compatible.
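As a quick sanity check of what this kind of transfer does, here is a stripped-down sketch of the Cholesky mode only, for a single CHW image, assuming PyTorch 1.9 or later for torch.linalg.cholesky. The name match_mean_cov is hypothetical; it is a simplification for illustration, not the Captum implementation.

```python
import torch

def match_mean_cov(x, ref, eps=1e-5):
    """Give x (C, H, W) the per-channel color mean and covariance of ref."""
    C = x.shape[0]

    def stats(t):
        mean = t.mean(dim=(1, 2), keepdim=True)
        vec = (t - mean).reshape(C, -1)
        # eps * I keeps the covariance positive definite for Cholesky
        eye = torch.eye(C, dtype=t.dtype, device=t.device)
        cov = vec @ vec.T / vec.shape[1] + eps * eye
        return mean, vec, cov

    _, x_vec, x_cov = stats(x)
    ref_mean, _, ref_cov = stats(ref)
    # Whiten with the input's Cholesky factor, re-color with the reference's
    transform = torch.linalg.cholesky(ref_cov) @ torch.inverse(
        torch.linalg.cholesky(x_cov)
    )
    return (transform @ x_vec).reshape(x.shape) + ref_mean
```

After the transform, the output's channel means and channel covariance agree with those of ref up to a term of order eps, and gradients flow back to x. Note that this matches first- and second-order color statistics rather than the full per-channel histogram.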

vedal commented 1 year ago

@ProGamerGov did this ever get implemented into any library?