rahul-goel / fused-ssim

Lightning fast differentiable SSIM.
MIT License

Fully Fused Differentiable SSIM

This repository contains an efficient, fully fused, differentiable implementation of SSIM. Several factors contribute to its speed.

As in the original SSIM paper, this implementation uses an 11x11 convolution kernel. The kernel weights are hardcoded, which is another reason for its speed. The implementation currently supports only 2D images, but with an arbitrary number of channels and any batch size.
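For reference, here is a sketch of the same computation in plain (non-fused) PyTorch: an 11x11 Gaussian window applied per channel, with the standard SSIM constants. The sigma of 1.5 and the constants c1, c2 follow the original SSIM paper; this is an illustrative stand-in, not the fused CUDA kernel itself.

```python
import torch
import torch.nn.functional as F

def gaussian_window(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    # 1D Gaussian, normalized, then outer product -> separable 2D window
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    return g.outer(g)

def reference_ssim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x, y: [BS, CH, H, W] in [0, 1]; returns mean SSIM over all pixels.
    c1, c2 = 0.01 ** 2, 0.03 ** 2  # standard constants for [0, 1] images
    ch = x.shape[1]
    w = gaussian_window().expand(ch, 1, 11, 11).contiguous()
    # groups=ch applies the same window to each channel independently;
    # padding=5 gives "same" padding for an 11x11 kernel
    mu_x = F.conv2d(x, w, padding=5, groups=ch)
    mu_y = F.conv2d(y, w, padding=5, groups=ch)
    sigma_x = F.conv2d(x * x, w, padding=5, groups=ch) - mu_x ** 2
    sigma_y = F.conv2d(y * y, w, padding=5, groups=ch) - mu_y ** 2
    sigma_xy = F.conv2d(x * y, w, padding=5, groups=ch) - mu_x * mu_y
    num = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
    return (num / den).mean()
```

A fused kernel computes all five convolutions and the final ratio in one pass, avoiding the intermediate tensors that each conv2d call above materializes.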

PyTorch Installation Instructions

Usage

import torch
from fused_ssim import fused_ssim

# predicted_image, gt_image: [BS, CH, H, W]
# predicted_image is differentiable
gt_image = torch.rand(2, 3, 1080, 1920)
predicted_image = torch.nn.Parameter(torch.rand_like(gt_image))
ssim_value = fused_ssim(predicted_image, gt_image)
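Because the returned value is differentiable with respect to predicted_image, it can be maximized directly with an optimizer, typically by minimizing 1 - SSIM. A minimal sketch of such a loop follows; the ssim function below is a placeholder so the snippet runs without the CUDA extension, and in practice you would import fused_ssim instead.

```python
import torch

# Placeholder for fused_ssim so this sketch is self-contained;
# in practice: from fused_ssim import fused_ssim as ssim
def ssim(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    # NOT real SSIM: just a differentiable similarity in (0, 1]
    return 1.0 / (1.0 + (pred - gt).pow(2).mean())

gt_image = torch.rand(2, 3, 64, 64)
predicted_image = torch.nn.Parameter(torch.rand_like(gt_image))
opt = torch.optim.Adam([predicted_image], lr=1e-2)

for step in range(200):
    opt.zero_grad()
    loss = 1.0 - ssim(predicted_image, gt_image)  # maximize SSIM
    loss.backward()
    opt.step()
```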

By default, "same" padding is used. To use "valid" padding, which is the kind of padding used by pytorch-msssim:

ssim_value = fused_ssim(predicted_image, gt_image, padding="valid")

If you only need inference and no gradients, the following is even faster:

with torch.no_grad():
  ssim_value = fused_ssim(predicted_image, gt_image, train=False)

Constraints

Performance

This implementation is 5-8x faster than pytorch-msssim, the previous fastest differentiable SSIM implementation to the best of my knowledge.
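The exact speedup depends on GPU and image size, so it is worth measuring on your own hardware. A minimal timing harness is sketched below; the important detail is synchronizing CUDA before reading the clock, since kernel launches are asynchronous. Pass it any SSIM callable and its arguments.

```python
import time
import torch

def bench(fn, *args, warmup: int = 3, iters: int = 10) -> float:
    """Return the median seconds per call of fn(*args)."""
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish pending async kernels
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure the call is fully timed
        times.append(time.perf_counter() - t0)
    return sorted(times)[len(times) // 2]
```

Usage would be, e.g., bench(fused_ssim, predicted_image, gt_image) versus the same call into pytorch-msssim, with both inputs already resident on the GPU.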

BibTeX

If you leverage fused SSIM for your research work, please cite our main paper:

@inproceedings{taming3dgs,
    author    = {Mallick, Saswat Subhajyoti and Goel, Rahul and Kerbl, Bernhard and
                 Vicente Carrasco, Francisco and Steinberger, Markus and
                 De La Torre, Fernando},
    title     = {Taming 3DGS: High-Quality Radiance Fields with Limited Resources},
    booktitle = {SIGGRAPH Asia 2024 Conference Papers},
    year      = {2024},
    doi       = {10.1145/3680528.3687694},
    url       = {https://humansensinglab.github.io/taming-3dgs/}
}

Acknowledgements

Thanks to Bernhard for the idea.