v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

Extracting the coefficients #99

Closed rachelglenn closed 2 months ago

rachelglenn commented 3 months ago

Hi. I am trying to pull out the coefficients as a matrix, similar to pywt. How do I do that (see the last line)?

```python
import ptwt, pywt, torch
import numpy as np
import scipy.datasets

face = np.transpose(scipy.datasets.face(),
                    [2, 0, 1]).astype(np.float64)
pytorch_face = torch.tensor(face)
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                             level=2, mode="constant")
reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
np.max(np.abs(face - reconstruction.squeeze(1).numpy()))

cA, (cH, cV, cD) = coefficients
```

v0lta commented 3 months ago

Dear @rachelglenn, you are probably getting a `ValueError: too many values to unpack (expected 2)`. With `level=2`, `wavedec2` returns three entries (the approximation coefficients plus one detail tuple per level), but the last line unpacks them into only two targets.
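To see where the error comes from, here is a minimal sketch that uses placeholder strings (an assumption for illustration) instead of real tensors, arranged like a level-2 result: one approximation entry followed by one detail tuple per level.

```python
# Placeholder strings stand in for the tensors that a level-2
# decomposition returns: [cA2, (cH2, cV2, cD2), (cH1, cV1, cD1)].
coefficients = ["cA2", ("cH2", "cV2", "cD2"), ("cH1", "cV1", "cD1")]
level = 2

# One approximation entry plus one detail tuple per level.
assert len(coefficients) == level + 1

try:
    cA, (cH, cV, cD) = coefficients  # two targets, three values
except ValueError as err:
    message = str(err)

print(message)
```

The outer unpacking fails before Python ever looks inside the detail tuples, which is why the error reports the top-level count.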

You have two options:

1. You can add another detail tuple:

```python
import ptwt, pywt, torch
import numpy as np
import scipy.datasets

face = np.transpose(scipy.datasets.face(),
                    [2, 0, 1]).astype(np.float64)
pytorch_face = torch.tensor(face)
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                             level=2, mode="constant")
reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
np.max(np.abs(face - reconstruction.squeeze(1).numpy()))

# level=2 returns three entries: the approximation, then the
# level-2 (coarse) and level-1 (fine) detail tuples.
(cA2, (cH2, cV2, cD2), (cH1, cV1, cD1)) = coefficients
```

2. You can set the decomposition level to `1`:

```python
import ptwt, pywt, torch
import numpy as np
import scipy.datasets

face = np.transpose(scipy.datasets.face(),
                    [2, 0, 1]).astype(np.float64)
pytorch_face = torch.tensor(face)
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                             level=1, mode="constant")
reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
np.max(np.abs(face - reconstruction.squeeze(1).numpy()))

# level=1 returns two entries: the approximation and one detail tuple.
cA, (cH, cV, cD) = coefficients
```

I hope this helps you solve your problem.
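If the level changes between runs, hard-coding the number of detail tuples gets brittle. A minimal sketch, assuming the coarsest-first ordering from option 1 (`cA2`, then the level-2 details, then the level-1 details): star-unpacking takes the approximation and collects however many detail tuples follow. `split_coefficients` is a hypothetical helper, and the strings are placeholders for tensors.

```python
def split_coefficients(coefficients):
    """Hypothetical helper: split a wavedec2-style coefficient sequence
    [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)] into the
    approximation and a dict that maps each level to its detail tuple."""
    cA, *details = coefficients  # works for any decomposition level
    n_levels = len(details)
    # details[0] is the coarsest level (n_levels), details[-1] is level 1.
    by_level = {n_levels - i: tuple(d) for i, d in enumerate(details)}
    return cA, by_level


# Placeholder strings stand in for the tensors of a level-2 decomposition.
coefficients = ["cA2", ("cH2", "cV2", "cD2"), ("cH1", "cV1", "cD1")]
cA, details = split_coefficients(coefficients)
cH1, cV1, cD1 = details[1]  # finest level
cH2, cV2, cD2 = details[2]  # coarsest level
```

Keying the dict by level number keeps the code readable when you only need one particular scale.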

felixblanke commented 3 months ago

I am not quite sure what you mean by "coefficients as a matrix". In case you mean the functionality of `pywt.coeffs_to_array`, I wrote a small snippet:

```python
import ptwt, pywt, torch
import numpy as np
import scipy.datasets

def coeffs_to_array(coeffs: ptwt.WaveletCoeff2d) -> torch.Tensor:
    # Start from the coarsest approximation and wrap each detail level
    # around it: cV to the right, cH below, cD diagonally.
    cA = coeffs[0]

    for detail_coeffs in coeffs[1:]:
        cH, cV, cD = detail_coeffs
        row0 = torch.cat([cA, cV], dim=-1)
        row1 = torch.cat([cH, cD], dim=-1)
        cA = torch.cat([row0, row1], dim=-2)
    return cA

face = np.transpose(scipy.datasets.face(), [2, 0, 1]).astype(np.float64)
pytorch_face = torch.tensor(face)

coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"), level=2, mode="constant")
coefficients_pywt = pywt.wavedec2(face, pywt.Wavelet("haar"), level=2, mode="constant")

ptwt_arr = coeffs_to_array(coefficients)
pywt_arr, _ = pywt.coeffs_to_array(coefficients_pywt, axes=(-2, -1))

# The torch layout matches the one produced by pywt.coeffs_to_array.
assert ptwt_arr.shape == pywt_arr.shape
assert np.isclose(ptwt_arr.numpy(), pywt_arr).all()
```

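To make the quadrant placement visible without PyTorch or an image, here is a NumPy mirror of the function above (the name `coeffs_to_array_np` is hypothetical) applied to constant-valued placeholder blocks. Each number marks one sub-band, so the printout shows where `cA`, `cH`, `cV` and `cD` end up at each level.

```python
import numpy as np


def coeffs_to_array_np(coeffs):
    """NumPy mirror of the torch snippet (hypothetical helper name).

    Assumes each detail array has the same shape as the running
    approximation block, as with the Haar wavelet on inputs whose sides
    are divisible by 2**level. Other wavelets or sizes can produce
    mismatched shapes; pywt.coeffs_to_array pads those, this sketch does not.
    """
    arr = np.asarray(coeffs[0])
    for cH, cV, cD in coeffs[1:]:
        top = np.concatenate([arr, cV], axis=-1)     # cA | cV
        bottom = np.concatenate([cH, cD], axis=-1)   # cH | cD
        arr = np.concatenate([top, bottom], axis=-2)
    return arr


# Label every sub-band with a constant so its position is visible:
# a level-2 layout for a 4x4 input.
coeffs = [
    np.full((1, 1), 0),                                            # cA2
    (np.full((1, 1), 1), np.full((1, 1), 2), np.full((1, 1), 3)),  # cH2, cV2, cD2
    (np.full((2, 2), 4), np.full((2, 2), 5), np.full((2, 2), 6)),  # cH1, cV1, cD1
]
layout = coeffs_to_array_np(coeffs)
print(layout)
# [[0 2 5 5]
#  [1 3 5 5]
#  [4 4 6 6]
#  [4 4 6 6]]
```

The coarsest level sits in the top-left corner, and each finer level wraps around it with vertical details to the right, horizontal details below, and diagonal details in the bottom-right corner.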
v0lta commented 2 months ago

@rachelglenn, is your question answered?

v0lta commented 2 months ago

Closing due to inactivity. Feel free to reopen.