Stable-X / StableDelight

StableDelight: Revealing Hidden Textures by Removing Specular Reflections

Could this be used to extract the reflection? #2

Closed fever308 closed 1 month ago

fever308 commented 1 month ago

This is really neat. I was wondering whether it's possible to get the reflections themselves as an output too?

hugoycj commented 1 month ago

It's feasible! We have actually implemented a reflection detection function in another project, which uses reflection extraction as an intermediate step:

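import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


def generate_specular(rgb_image, diffuse_image, kernel_size=15, threshold=2):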
    """
    Generate a specular reflection map by subtracting the diffuse image from the RGB image using PyTorch.

    :param rgb_image: RGB image as a PIL Image
    :param diffuse_image: Diffuse image as a PIL Image
    :param kernel_size: Size of the box kernel for local smoothing (use an odd value so the output keeps the input size)
    :param threshold: Divisor for normalization; smoothed values are divided by it and clipped to [0, 1]
    :return: specular reflection map as a PIL Image
    """

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

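    # Convert the PIL images to float tensors in [0, 1] with shape (C, H, W);
    # convert("RGB") drops any alpha channel so both tensors have 3 channels
    to_tensor = transforms.ToTensor()
    rgb_tensor = to_tensor(rgb_image.convert("RGB")).to(device)
    diffuse_tensor = to_tensor(diffuse_image.convert("RGB")).to(device)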

    # Compute specular tensor
    specular_tensor = rgb_tensor - diffuse_tensor

    # Clip negative values to 0
    specular_tensor = torch.clamp(specular_tensor, min=0.0)

    # Apply local smoothing
    padding = kernel_size // 2
    specular_smoothed = F.avg_pool2d(specular_tensor, kernel_size, stride=1, padding=padding)

    # Normalize using the threshold
    specular_normalized = torch.clamp(specular_smoothed / threshold, min=0.0, max=1.0)

    # Convert to grayscale
    specular_gray = specular_normalized.mean(dim=0, keepdim=True)

    # Convert to 0-255 range
    specular_uint8 = (specular_gray * 255).permute(1, 2, 0)

    # Convert to PIL Image
    specular_image = Image.fromarray(specular_uint8.squeeze().cpu().numpy().astype(np.uint8))

    return specular_image
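
For anyone who wants to check the steps without PyTorch, here is a rough NumPy-only sketch of the same pipeline (subtract, clip, box-smooth, normalize, grayscale). It is not the code from the project mentioned above: the names `box_filter` and `specular_map_np` are made up for illustration, the inputs are assumed to be float arrays in [0, 1] with shape (H, W, 3), and the zero-padded mean filter is only meant to mimic the default behaviour of `F.avg_pool2d` for odd kernel sizes.

```python
import numpy as np


def box_filter(channel, k):
    """Zero-padded k x k mean filter (odd k) using a summed-area table."""
    pad = k // 2
    padded = np.pad(channel, pad, mode="constant")
    # Summed-area table with an extra leading row and column of zeros.
    sat = np.zeros((padded.shape[0] + 1, padded.shape[1] + 1), dtype=np.float64)
    sat[1:, 1:] = padded.cumsum(axis=0).cumsum(axis=1)
    h, w = channel.shape
    window_sum = (sat[k:k + h, k:k + w] - sat[:h, k:k + w]
                  - sat[k:k + h, :w] + sat[:h, :w])
    # Divide by the full window area, padding included.
    return window_sum / (k * k)


def specular_map_np(rgb, diffuse, kernel_size=15, threshold=2.0):
    """Hypothetical NumPy counterpart of the steps in generate_specular."""
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd to keep the input size")
    # Residual between the input and its diffuse estimate, negatives clipped.
    residual = np.clip(rgb.astype(np.float64) - diffuse.astype(np.float64), 0.0, None)
    # Smooth each colour channel separately.
    smoothed = np.stack(
        [box_filter(residual[..., c], kernel_size) for c in range(residual.shape[-1])],
        axis=-1,
    )
    # Normalize by the threshold, average the channels, scale to 0-255.
    normalized = np.clip(smoothed / threshold, 0.0, 1.0)
    gray = normalized.mean(axis=-1)
    return (gray * 255).astype(np.uint8)


# Synthetic example: a flat diffuse image plus one bright square "highlight".
diffuse = np.full((32, 32, 3), 0.25)
rgb = diffuse.copy()
rgb[12:20, 12:20, :] = 0.75
spec = specular_map_np(rgb, diffuse, kernel_size=3, threshold=1.0)
```

Note that with inputs in [0, 1] the residual never exceeds 1, so a `threshold` of 2 caps the map at roughly half intensity; a smaller threshold gives a brighter map.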