InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

feat: SpectralConv2d, SpectralConvTranspose2d #258

Closed YodaEmbedding closed 6 months ago

YodaEmbedding commented 10 months ago

Introduced in "Efficient Nonlinear Transforms for Lossy Image Compression" by Johannes Ballé, PCS 2018. Reparameterizes the convolution weights so that they are derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation for representing the weights in the frequency domain is that optimizer updates/steps can then affect all frequencies by an equal amount. This improves the gradient conditioning, leading to faster convergence and increased stability at larger learning rates.

spectral_adam

spectral_adam_rd_curves

For comparison, see the TensorFlow Compression implementations of SignalConv2D and RDFTParameter. They seem to use SignalConv2D in most of their provided architectures: https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code

Furthermore, since this is a simple invertible transformation on the weights, it is trivial to convert any existing pretrained weights into this form via:

weight_transformed = self._to_transform_domain(weight)
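As a rough illustration of that conversion, here is a minimal sketch that maps a regular `nn.Conv2d` state dict into the frequency-domain form and back. The free functions `to_transform_domain` and `from_transform_domain` are hypothetical stand-ins for the private methods. They assume an orthonormal real FFT over the two spatial dimensions, matching the `irfftn` call quoted later in this thread. The randomly initialized layer only stands in for a pretrained one.

```python
import torch
import torch.nn as nn


def to_transform_domain(weight, kernel_size):
    # Assumed forward transform: real FFT over the spatial dims of the kernel.
    return torch.fft.rfftn(weight, s=kernel_size, dim=(-2, -1), norm="ortho")


def from_transform_domain(weight_transformed, kernel_size):
    # Inverse transform; `s` restores the original (possibly odd) kernel size.
    return torch.fft.irfftn(weight_transformed, s=kernel_size, dim=(-2, -1), norm="ortho")


torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=5, padding=2)  # stands in for a pretrained layer

# Replace the "weight" entry with its frequency-domain counterpart.
state = conv.state_dict()
state["weight_transformed"] = to_transform_domain(state.pop("weight"), conv.kernel_size)

# The transform is invertible, so the spatial kernel is recovered up to float error.
recovered = from_transform_domain(state["weight_transformed"], conv.kernel_size)
max_err = (recovered - conv.weight).abs().max().item()
print(f"max round-trip error: {max_err:.2e}")
```

The transformed tensor is complex and stores only half of the spectrum along the last dimension, which is why `s=kernel_size` is passed to the inverse.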

To override self.weight as a property, I'm unregistering the parameter using del self._parameters["weight"], as shown in https://github.com/pytorch/pytorch/issues/46886, and relying on the fact that @property creates a descriptor on the class, so that self.weight "falls back" to the property.

    def __init__(self, ...):
        self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
        del self._parameters["weight"]  # Unregister weight, and fallback to property.

    @property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_transformed)
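To make the fragment above concrete, here is a self-contained sketch of the same trick. It is not the code in this PR: the class name `SpectralConv2dSketch`, the `dim` attribute, and the choice of an orthonormal `rfftn`/`irfftn` pair are assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch import Tensor


class SpectralConv2dSketch(nn.Conv2d):
    """Conv2d whose kernel is stored as real-FFT coefficients (illustrative only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dim = (-2, -1)
        # weight_transformed does not exist yet, so the property below raises
        # AttributeError and self.weight resolves to the registered parameter.
        spatial_weight = self.weight.detach()
        self.weight_transformed = nn.Parameter(self._to_transform_domain(spatial_weight))
        del self._parameters["weight"]  # Unregister weight, and fall back to property.

    @property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_transformed)

    def _to_transform_domain(self, x: Tensor) -> Tensor:
        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")

    def _from_transform_domain(self, x: Tensor) -> Tensor:
        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")


torch.manual_seed(0)
conv = SpectralConv2dSketch(3, 8, kernel_size=5, padding=2)
param_names = sorted(name for name, _ in conv.named_parameters())
print(param_names)

# The inherited Conv2d.forward reads self.weight, i.e. the property.
y = conv(torch.rand(2, 3, 16, 16))
y.mean().backward()
```

Because `nn.Conv2d.forward` only ever reads `self.weight`, no forward override is needed. Gradients flow through `irfftn` into the complex-valued `weight_transformed` parameter.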


YodaEmbedding commented 10 months ago

SpectralConv2d vs Conv2d mini-experiments

Below are a couple of example runs (with different randomly initialized kernels) to compare SpectralConv2d vs Conv2d. (Ignore the inaccurate titles; I accidentally set the title for everything to "smoothing kernel".)


sinusoidal channel averaging:

conv_kwargs = dict(in_channels=3, out_channels=2, kernel_size=11, padding=5)

def init_conv_target(conv):
    k = 0.05 * torch.linspace(0, 2 * torch.pi, conv.kernel_size[0], device=device).sin()
    conv.weight.data[:] = k

fit_spectralconv2d

fit_spectralconv2d

fit_spectralconv2d

fit_spectralconv2d


Simple depthwise cosine "smoothing" kernel:

conv_kwargs = dict(in_channels=8, out_channels=8, kernel_size=11, padding=5)

def init_conv_target(conv):
    k = torch.linspace(0, torch.pi, conv.kernel_size[0], device=device).sin()
    k = k * k[:, None]
    k = k / k.sum()
    idx = torch.arange(conv.in_channels, device=device)
    conv.weight.data[:] = 0
    conv.weight.data[idx, idx, :, :] = k

fit_spectralconv2d


Simple depthwise vertical edge-detector kernel:

conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=5, padding=5, groups=4)

def init_conv_target(conv):
    K = conv.kernel_size[0]
    k = 2 * torch.linspace(-1, 1, K, device=device) / K**2
    conv.weight.data[:] = k

fit_spectralconv2d


Random initialization:

conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=5, padding=5)

def init_conv_target(conv):
    pass  # Maintain random initialization.

fit_spectralconv2d

Surprisingly, SpectralConv2d is not worse...!


Figures generated via:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from compressai.layers import SpectralConv2d

device = "cuda"

def train(conv, conv_target, max_steps=1000, lr=1e-4, batch_size=16):
    losses = []
    for step in range(max_steps):
        conv.zero_grad()
        x = torch.rand((batch_size, conv.in_channels, 256, 256)).to(device)
        y = conv_target(x).detach()
        y_hat = conv(x)
        loss = ((y - y_hat) ** 2).mean()
        loss.backward()
        for param in conv.parameters():
            param.data -= param.grad * lr
        print(f"step {step:04d}  loss {loss.item():.4f}")
        losses.append(loss.item())
    return losses

models = {
    "spectral": SpectralConv2d(**conv_kwargs),
    "regular": nn.Conv2d(**conv_kwargs),
}

# Initialize all models to exact same random initialization.
conv_rand = nn.Conv2d(**conv_kwargs).to(device)
init_weight = conv_rand.weight.data
init_bias = conv_rand.bias.data

for key, conv in models.items():
    conv.to(device)
    if isinstance(conv, SpectralConv2d):
        conv.weight_transformed.data = conv._to_transform_domain(init_weight).clone()
    else:
        conv.weight.data = init_weight.clone()
    conv.bias.data = init_bias.clone()

# Initialize "ideal"/target model kernels using the functions defined above.
conv_target = nn.Conv2d(**conv_kwargs).to(device)
init_conv_target(conv_target)

results = {key: train(conv, conv_target) for key, conv in models.items()}

fig, ax = plt.subplots()
for key, y in results.items():
    ax.plot(y, label=key)
ax.legend()
ax.set(xlabel="step", ylabel="loss", title="Fitting to a target kernel")
fig.savefig("fit_spectralconv2d.png")

Random musings:

I wonder why no one's applied this to other image problems (classification/superresolution/etc)? Those problems should also be concerned with "frequency regularized" kernels (so to speak).

I wonder if this could be turned into a short "paper" with more interesting experiments... Or if not, perhaps at least a "note" (1, 2) on arXiv.

Also, another related thought that I had once upon a time:

Generate the kernel from a different basis, e.g. the first $K=4$ 2D DCT basis elements. Assuming a single input and output channel, `y = conv(x, weight=k(w))`, where $w = (w_1, \ldots, w_K)$ are trainable weights and $k(w)$ generates the kernel via the weighted summation $k(w) = \sum_i w_i e_i$. Then, even large 7x7 or 9x9 kernels can be expressed by only a few parameters. Extend appropriately to multiple input/output channels.

I guess this looks very related to "dynamic convolution" / "CondConv", but with 10x fewer parameters instead of 4x more parameters.

I think I like Balle's insight of using a big DFT and calling it a day, though.

I guess one more thing we could do as an extension to Balle's "Spectral" reparameterization is something like:

```python
def _from_transform_domain(self, w: Tensor) -> Tensor:
    # Attenuate out high-frequencies.
    # Not sure if correctly written, but this is intended to keep the lower frequencies, and push the others to 0.
    # d/dw is then [hopefully larger for the lower frequencies, but I should check if I wrote this correctly...].
    # Caveat: along dim -2 the bins follow FFT order (positive, then negative
    # frequencies), so a monotone ramp there is not symmetric in |frequency|.
    # yy has shape (H, 1) so that yy * xx broadcasts to an (H, W) mask.
    yy = torch.linspace(1, 0.1, w.shape[-2], device=w.device)[:, None]
    xx = torch.linspace(1, 0.1, w.shape[-1], device=w.device)
    mask = yy * xx
    w = w * mask
    # Balle's original spectral reparametrization.
    return torch.fft.irfftn(w, s=self.kernel_size, dim=self.dim, norm="ortho")
```

Or if high frequencies are also important, do this "regularizing" reparameterization for only some portion of the weights. That way, the network is forced into a balance of both, inhabits a ~~lower dimensional space~~ well-conditioned/"regularized" space (c.f. dynamic conv's affine constraint), and is roughly just as expressive as a non-reparameterized network.

TODO: Create a test suite of target kernels like the above, or pulled from trained networks, and see how quickly different `_from_transform_domain` functions converge... Is there something that fits our networks better than Balle's unmodified spectral conv?

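The DCT-basis idea from the start of this comment could be prototyped roughly as follows, for the single-input, single-output channel case. Everything here is a hypothetical sketch: `dct_basis_2d`, `BasisConv2d`, the unnormalized DCT-II basis, and the low-frequency-first ordering of basis elements are all assumptions chosen for illustration.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def dct_basis_2d(size: int, num_basis: int) -> torch.Tensor:
    """Return the lowest-frequency 2D DCT-II basis images, shape (num_basis, size, size)."""
    n = torch.arange(size, dtype=torch.float32)
    # Row u of basis_1d is the (unnormalized) 1D basis vector cos(pi * (n + 0.5) * u / size).
    basis_1d = torch.cos(math.pi * (n[None, :] + 0.5) * n[:, None] / size)
    # Order the (u, v) frequency pairs by u + v so that low frequencies come first.
    pairs = sorted(
        ((u, v) for u in range(size) for v in range(size)),
        key=lambda p: (p[0] + p[1], p),
    )
    return torch.stack([torch.outer(basis_1d[u], basis_1d[v]) for u, v in pairs[:num_basis]])


class BasisConv2d(nn.Module):
    """Single-channel conv whose kernel is k(w) = sum_i w_i e_i over a fixed basis."""

    def __init__(self, kernel_size: int = 7, num_basis: int = 4):
        super().__init__()
        self.register_buffer("basis", dct_basis_2d(kernel_size, num_basis))
        self.w = nn.Parameter(0.1 * torch.randn(num_basis))
        self.padding = kernel_size // 2

    @property
    def weight(self) -> torch.Tensor:
        # Weighted sum of basis images, reshaped to (out_ch=1, in_ch=1, H, W).
        return torch.einsum("i,ihw->hw", self.w, self.basis)[None, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.weight, padding=self.padding)


torch.manual_seed(0)
layer = BasisConv2d(kernel_size=7, num_basis=4)
out = layer(torch.rand(2, 1, 16, 16))
num_params = sum(p.numel() for p in layer.parameters())
print(num_params, "trainable parameters for a 7x7 kernel")
```

With these settings the 7x7 kernel has 4 trainable parameters instead of 49, at the cost of only being able to represent kernels in the span of the chosen basis.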
YodaEmbedding commented 10 months ago

SpectralConv2d vs Conv2d experiments

First attempt

++model.name="bmshj2018-factorized" ++criterion.lmbda=0.0067 ++scheduler.net.threshold=5e-4

Training forward()-loss curves comparison
Blue: Conv2d. Green: SpectralConv2d.
Validation forward()-loss curves comparison
Blue: Conv2d. Green: SpectralConv2d. No clamp(0, 1) applied to x_hat.
Inference/test forward()-loss curves comparison
Blue: Conv2d. Green: SpectralConv2d. No clamp(0, 1) applied to x_hat.

However, even though the inference/test loss is evidently much smaller, the position of the point on the RD plot is actually less "optimal" w.r.t. the default configuration:

Conv2d SpectralConv2d

One possibility to mitigate this might be to initially train with the spectral transform enabled for the first 30 epochs, and then disable it for the remainder of the training.

Perhaps the "inaccurate" reconstruction loss ($D$) estimation is only a problem when "noise" quantization is used during training. Not sure. Maybe modifying the covariance structure (see Balle's paper, pg 2) also makes it easier for $g_s(y + \mathcal{N}(0, 1))$ to shift the distribution and generate relatively more "optimistic" $\hat{x}$ than rounding would...? Maybe STE might help...
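For reference, the three quantization proxies being contrasted here can be sketched as below. This assumes the usual training-time proxy of additive uniform noise in [-0.5, 0.5) rather than Gaussian noise, and the function names are hypothetical, not CompressAI API.

```python
import torch


def quantize_noise(y):
    # Training-time proxy: additive uniform noise in [-0.5, 0.5).
    return y + torch.empty_like(y).uniform_(-0.5, 0.5)


def quantize_round(y):
    # Inference-time quantization; its gradient is zero everywhere.
    return torch.round(y)


def quantize_ste(y):
    # Straight-through estimator: rounds in the forward pass,
    # behaves like the identity in the backward pass.
    return y + (torch.round(y) - y).detach()


torch.manual_seed(0)
y = (4 * torch.randn(1000)).requires_grad_()

y_noise = quantize_noise(y)
y_round = quantize_round(y)
y_ste = quantize_ste(y)

(grad_round,) = torch.autograd.grad(y_round.sum(), y)
(grad_ste,) = torch.autograd.grad(y_ste.sum(), y)
print(grad_round.abs().max().item(), grad_ste.mean().item())
```

With STE, the decoder sees exactly the rounded values during training, so the training-time distortion estimate and the inference-time one use the same `y_hat`.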

Contour lines of equal loss for $\lambda=0.0067$
Code for generating the loss contour plot:

```python
import json

import matplotlib.pyplot as plt
import numpy as np

RESULTS_DIR = "/home/mulhaq/code/research/compressai/master/results"
CODECS = [f"{RESULTS_DIR}/image/kodak/compressai-bmshj2018-factorized_mse_cuda.json"]


def mse_to_psnr(mse, max_value=1.0):
    return -10 * np.log10(mse / max_value**2)


def psnr_to_mse(psnr, max_value=1.0):
    return 10 ** (-psnr / 10) * max_value**2


def main():
    xlim = (0, 2.25)
    ylim = (26, 41)
    x = np.linspace(*xlim, 50)
    y = np.linspace(*ylim, 50)
    R = x
    D = psnr_to_mse(y)

    lmbdas = [0.0067]
    cmaps = ["Greys"]
    # lmbdas = [0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800]
    # cmaps = ["Reds", "Oranges", "YlOrBr", "Greens", "Blues", "Purples", "Greys", "RdPu"]
    levels = np.logspace(-2, 0.5, 200)

    fig, ax = plt.subplots(figsize=(8, 6))

    for lmbda, cmap in zip(lmbdas, cmaps):
        loss = R + lmbda * 255**2 * D[:, None]
        im = ax.contour(x, y, loss, levels=levels, cmap=cmap)
        cbar = fig.colorbar(im, ax=ax, fraction=0.08, pad=0.01)
        cbar.set_ticks([round(tick, 2) for tick in cbar.ax.get_yticks()])

    # RD curves.
    for codec in CODECS:
        with open(codec, "r") as f:
            data = json.load(f)
        ax.plot(
            data["results"]["bpp"],
            data["results"]["psnr-rgb"],
            ".-",
            label=data["name"],
        )

    # Custom points.
    ax_kwargs = dict(zorder=100, s=8)
    series = [
        dict(
            x=[0.308173],
            y=[29.9231],
            label="Conv2d (compress/decompress)",
            color="C1",
            marker="*",
        ),
        dict(
            x=[0.322696],
            y=[30.0040],
            label="SpectralConv2d (compress/decompress)",
            color="C6",
            marker="*",
        ),
        dict(
            x=[0.307627],
            y=[29.1858],
            label="Conv2d (forward) (no clamp)",
            color="C1",
        ),
        dict(
            x=[0.322088],
            y=[29.3400],
            label="SpectralConv2d (forward) (no clamp)",
            color="C6",
        ),
    ]
    for series_i in series:
        ax.scatter(**series_i, **ax_kwargs)

    # Finalize.
    ax.set(
        xlabel="Bit-rate [bpp]",
        ylabel="PSNR [dB]",
        title="Loss surface",
        xlim=xlim,
        ylim=ylim,
    )
    ax.legend(loc="lower right", fontsize="small")
    fig.savefig("loss_surface.png", dpi=300)


if __name__ == "__main__":
    main()
```

Still, though... a point with lower $L$ actually has worse $(R,D)$ relative to the optimal achievable $(R, D)$ for a given model architecture?!
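As a quick sanity check on this puzzle, the four (bpp, PSNR) points from the plotting code can be plugged back into the loss $L = R + \lambda \cdot 255^2 \cdot D$ used for the contours. This is only a rough sketch: it treats each averaged PSNR as if it came from a single averaged MSE, which is an approximation, and the helper `rd_loss` is a hypothetical name.

```python
LMBDA = 0.0067


def psnr_to_mse(psnr, max_value=1.0):
    return 10 ** (-psnr / 10) * max_value**2


def rd_loss(bpp, psnr, lmbda=LMBDA):
    # Same form as the contour plot: rate plus lambda-weighted MSE on a 255 scale.
    return bpp + lmbda * 255**2 * psnr_to_mse(psnr)


# (bpp, PSNR in dB) points copied from the plotting code in this comment.
points = {
    "Conv2d (compress/decompress)": (0.308173, 29.9231),
    "SpectralConv2d (compress/decompress)": (0.322696, 30.0040),
    "Conv2d (forward) (no clamp)": (0.307627, 29.1858),
    "SpectralConv2d (forward) (no clamp)": (0.322088, 29.3400),
}

losses = {label: rd_loss(bpp, psnr) for label, (bpp, psnr) in points.items()}
for label, loss in losses.items():
    print(f"{label:40s} L = {loss:.4f}")
```

Under these assumptions the ordering flips between the two measurement paths: SpectralConv2d has the lower approximate loss on the unclamped `forward` points, while Conv2d has the lower one on the `compress/decompress` points.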

Code wars episodes I-III.

EDIT I (The `eval` menace): The plotted PSNR was measured using `compress/decompress`, whereas the loss was measured using `forward`'s `mse_loss`, which is somehow much worse. What may be happening is that the loss shown in the first figure was probably being measured on the noise-quantized y_hat, rather than round-quantized y_hat. But doesn't eval mode disable noise quantization?

EDIT II (The `.clone()` wars): Eval mode is certainly set, according to https://github.com/catalyst-team/catalyst/blob/v22.04/catalyst/core/runner.py#L312 which calls [`model.train(mode=False)`](https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/modules/module.py#L2269-L2289) (which recursively updates all submodules to `self.training=False`). It is correctly using `"dequantize"` for valid/infer rather than `"noise"`, so still I wonder what could be causing the worse `mse_loss`...

EDIT III (Revenge of the `x_hat`): For `forward` and `compress/decompress`, both `y` and `y_hat` are exactly the same, but `x_hat` isn't. Thus, the problem only occurs during `g_s(y_hat)`.

EDIT IV (A new hope): Found it! There's no .clamp_(0, 1) in the forward, which makes sense during training.

def forward(...):
        x_hat = self.g_s(y_hat)

def decompress(...):
        x_hat = self.g_s(y_hat).clamp_(0, 1)

By clamping x_hat when not in training mode, both methods now give exactly the same results. It's quite surprising that simple clamping causes such a big gain in PSNR (between 0.66 dB and 0.74 dB). It's also odd that the SpectralConv2d-trained model only gets a 0.66 dB jump while the Conv2d-trained model enjoys 0.74 dB. But maybe that's just luck.
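The direction of the effect is easy to reproduce on synthetic data. The sketch below uses a noisy copy of a random image as a stand-in for `g_s(y_hat)`, so the size of the gain is not comparable to the numbers above, and the `postprocess` helper is a hypothetical name.

```python
import torch


def psnr(x, x_hat, max_value=1.0):
    mse = torch.mean((x - x_hat) ** 2)
    return (10 * torch.log10(max_value**2 / mse)).item()


def postprocess(x_hat, training):
    # Leave the reconstruction untouched for training; clamp to the valid
    # pixel range when evaluating.
    return x_hat if training else x_hat.clamp(0, 1)


torch.manual_seed(0)
x = torch.rand(1, 3, 64, 64)            # target image with values in [0, 1)
x_hat = x + 0.05 * torch.randn_like(x)  # synthetic stand-in for g_s(y_hat)

psnr_raw = psnr(x, postprocess(x_hat, training=True))
psnr_clamped = psnr(x, postprocess(x_hat, training=False))
print(f"no clamp: {psnr_raw:.2f} dB, clamp: {psnr_clamped:.2f} dB")
```

Since every target pixel lies in [0, 1], clamping moves each out-of-range prediction closer to its target and leaves the rest unchanged, so the clamped MSE can never be larger than the unclamped one.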

EDIT V (The scheduler strikes back): According to the training losses, it looks like the SpectralConv2d-trained model dropped LR earlier, whereas the Conv2d-trained model happily went on for an additional 100 epochs before the first LR drop! Here were the settings I used:

scheduler:
  net:
    type: "ReduceLROnPlateau"
    mode: "min"
    factor: 0.1
    patience: 10
    threshold: 5e-4  # default is 1e-4

The CompressAI Trainer config uses the defaults; only my personal config has the 5e-4. Upon further reflection, it looks like the reduction in loss was slowing down anyway, so scheduler.net.threshold=5e-4 is probably not that bad an idea. I'll give 1e-4 a try anyway.
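To see how much the threshold alone can matter, here is a small sketch that feeds the same synthetic, slowly improving loss curve to two `ReduceLROnPlateau` schedulers. The 3e-5 relative improvement per epoch and the helper `final_lr` are made up for illustration; only the scheduler settings mirror the config above.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


def final_lr(threshold, num_epochs=20, rel_improvement=3e-5, lr=1e-4):
    """Run a synthetic loss curve through the scheduler and return the final LR."""
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=lr)
    scheduler = ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=10, threshold=threshold
    )
    loss = 1.0
    for _ in range(num_epochs):
        optimizer.step()  # No gradients are set, so this does not change param.
        scheduler.step(loss)
        loss *= 1 - rel_improvement  # Tiny but steady relative improvement.
    return optimizer.param_groups[0]["lr"]


lr_default = final_lr(threshold=1e-4)
lr_custom = final_lr(threshold=5e-4)
print(f"threshold=1e-4 -> lr={lr_default:g}, threshold=5e-4 -> lr={lr_custom:g}")
```

In the default relative mode, an epoch only counts as an improvement if the loss falls below `best * (1 - threshold)`. A curve that keeps improving slowly can therefore satisfy the 1e-4 threshold often enough to avoid a drop, while failing the 5e-4 threshold for more than `patience` consecutive epochs.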

EDIT VI (Return of the multi-stage training): Perhaps best is to initially train with spectral-mode, then auto-schedule switching to regular-mode some time before the first LR drop occurs.

That may (or may not?) require changing the training code, however, i.e., a new Runner? I wonder if PyTorch has schedulers that can work with regular modules, and not just optimizers. Otherwise, I guess I could create a pseudo-optimizer that doesn't actually optimize anything, but merely switches conv.enable_transform=False for all the convs when a pseudo-LR drop occurs... That's a convoluted way of avoiding adding a new runner. Maybe a callback is less convoluted. But then I have to write my own pseudo-LR drop logic...?!
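One way the callback variant could look is sketched below, with the plateau logic written out by hand. The class `PlateauSwitch` is hypothetical, and `enable_transform` is assumed to be a plain boolean attribute on each layer as mentioned above. Plain namespace objects stand in for the conv layers, so the sketch stays independent of any runner or framework.

```python
from types import SimpleNamespace


class PlateauSwitch:
    """Set `enable_transform = False` on all layers once the loss plateaus."""

    def __init__(self, layers, patience=5, threshold=1e-4):
        self.layers = list(layers)
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0
        self.switched = False

    def step(self, loss):
        """Record one epoch's loss; return True on the call that triggers the switch."""
        if self.switched:
            return False
        # Relative improvement test, as in ReduceLROnPlateau's default mode.
        if loss < self.best * (1 - self.threshold):
            self.best = loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            for layer in self.layers:
                layer.enable_transform = False
            self.switched = True
            return True
        return False


# Stand-ins for spectral conv layers.
layers = [SimpleNamespace(enable_transform=True) for _ in range(3)]
switch = PlateauSwitch(layers, patience=5)

for epoch, loss in enumerate([1.0, 0.9, 0.8] + [0.8] * 10):
    if switch.step(loss):
        print(f"switched to regular mode at epoch {epoch}")
```

Using a smaller patience than the LR scheduler's (5 versus 10 here) is one simple way to make the switch fire before the first LR drop.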


Second attempt

++model.name="bmshj2018-factorized" ++criterion.lmbda=0.0067 (with default scheduler.net.threshold=1e-4)

Inference/test forward()-loss curves comparison
Green: Conv2d. Blue: SpectralConv2d. No clamp(0, 1) applied to x_hat.
Conv2d SpectralConv2d

Much better.

Multi-stage training might still be a good idea, though, since at around loss=0.655 (i.e. 40 epochs for SpectralConv2d) it looks like Conv2d starts converging more quickly than SpectralConv2d.

fracape commented 6 months ago

merged together with #270