rfeinman / pytorch-lasso

L1-regularized least squares with PyTorch
MIT License

ista_conv2d is solving the optimization problem for transposed convolutions #2

Closed · MikeLasz closed this 1 year ago

MikeLasz commented 2 years ago

Hi @rfeinman, thanks for providing these practical tools.

However, I might have spotted a mistake in lasso.conv2d.ista.ista_conv2d. I believe this function does not return the solution for the case where $A$ is a convolution; instead, it returns the result for the case where $A$ corresponds to a transposed convolution. The ISTA step is

$$ x_{k+1} = \mathcal{T}_{\lambda t} \bigl( x_k - 2tA^T (Ax_k - b) \bigr) \enspace . $$
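For reference, here is a minimal sketch of this update for a dense operator $A$ in plain PyTorch (the names dense_ista_step, A, b, lr, and alpha are illustrative, not part of pytorch-lasso; the factor 2 is absorbed into the step size, as in the library code below):

import torch
import torch.nn.functional as F

def dense_ista_step(zk, A, b, lr, alpha):
    # gradient of 0.5 * ||A z - b||^2 is A^T (A z - b)
    grad = A.t() @ (A @ zk - b)
    # soft-thresholding T_{alpha * lr} shrinks each entry toward zero
    return F.softshrink(zk - lr * grad, alpha * lr)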

Comparing this to your ISTA-step in lasso.conv2d.ista.ista_conv2d:

def rss_grad(zk):
    # forward map is conv_transpose2d, adjoint map is conv2d
    x_hat = F.conv_transpose2d(zk, weight, stride=stride, padding=padding)
    return F.conv2d(x_hat - x, weight, stride=stride, padding=padding)

# ista step function
def step(zk):
    return F.softshrink(zk - lr * rss_grad(zk), alpha * lr)

Hence, your update step calculates

$$ x_{k+1} = \mathcal{T}_{\lambda t} \bigl( x_k - 2tA (A^T x_k - b) \bigr) \enspace , $$

which is the update for the problem in which $A$ is a transposed convolution. You can run the following example snippet to test my claim:

# Example Script:
import torch
import torch.nn.functional as F
from lasso.conv2d.ista import ista_conv2d

inp = torch.tensor([[1., 2, 1],
                    [2, 0, 1],
                    [3, 1, 1]]).reshape(1, 1, 3, 3)
weight = torch.tensor([[0.5, 1],
                       [2, 3]]).reshape(1, 1, 2, 2)

# build the observation with a *transposed* convolution
output = F.conv_transpose2d(inp, weight)
x0 = torch.randn_like(inp)

# with alpha=0 this is plain least squares; if ista_conv2d solved the
# convolution problem, it would not recover inp from this observation
inp_ista = ista_conv2d(output, x0, weight, alpha=.0, lr=.01, maxiter=100)
print(inp_ista)  # approximately equal to inp

What do you think, @rfeinman? Am I missing something? Otherwise, I would be happy to contribute a fixed version that handles transposed convolutions as well as regular convolutions. Furthermore, note that the Lipschitz-constant estimation procedure can remain the same, since $\operatorname{Lip}(A) = \sigma_{\max}(A) = \sigma_{\max}(A^T) = \operatorname{Lip}(A^T)$.
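As a concrete illustration of that last point, here is a minimal power-iteration sketch for estimating $\sigma_{\max}$ of the convolutional operator (estimate_lipschitz, z_shape, and iters are illustrative names, not the library's actual routine):

import torch
import torch.nn.functional as F

def estimate_lipschitz(weight, z_shape, stride=1, padding=0, iters=50):
    # power iteration: alternately apply the conv_transpose and conv maps;
    # the largest eigenvalue of the composition is sigma_max^2, and
    # sigma_max is the same for A and A^T
    z = torch.randn(z_shape)
    for _ in range(iters):
        z = z / z.norm()
        Az = F.conv_transpose2d(z, weight, stride=stride, padding=padding)
        z = F.conv2d(Az, weight, stride=stride, padding=padding)
    return z.norm().sqrt()  # approximates sigma_max, i.e. Lip(A)

Its square, $\sigma_{\max}^2$, is the Lipschitz constant of the least-squares gradient, which is what bounds the usable step size $t$.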

rfeinman commented 2 years ago

Hi @MikeLasz

Yes, you are correct: I have defined the linear operator $A x$ as PyTorch's conv_transpose2d (and therefore $A^T x$ as conv2d). This choice was intentional, but perhaps it is confusing given the function name. In sparse coding, the operator $A$ is often referred to as the decoder; therefore I found it more fitting to associate $A$ with conv_transpose, which is more commonly used in convolutional decoders.
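For what it's worth, that pairing is easy to verify numerically: with $A$ = conv_transpose2d and $A^T$ = conv2d (stride=1, padding=0), the inner products $\langle u, Av \rangle$ and $\langle A^T u, v \rangle$ agree. A quick illustrative check, not library code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(1, 1, 2, 2)
v = torch.randn(1, 1, 3, 3)        # code-domain vector
u = torch.randn(1, 1, 4, 4)        # image-domain vector (3 + 2 - 1 = 4)

Av = F.conv_transpose2d(v, weight)  # A v
Atu = F.conv2d(u, weight)           # A^T u

# the two inner products agree up to floating-point error
print(torch.dot(u.flatten(), Av.flatten()))
print(torch.dot(Atu.flatten(), v.flatten()))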

As you note, conv vs. conv_transpose is only a technicality, and switching the configuration is straightforward. For the case of stride=1 it's as simple as passing weight.transpose(0,1).flip([2, 3]) in place of weight (up to a difference in padding ops).
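A quick sketch of that equivalence for the single-channel, stride=1 case (illustrative names; the padding=k-1 on the conv side accounts for the padding difference mentioned above):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 3, 3)
w = torch.randn(1, 1, 2, 2)
k = w.shape[-1]

out_t = F.conv_transpose2d(x, w)
# same result as a 'full' convolution with the swapped-and-flipped kernel
out_c = F.conv2d(x, w.transpose(0, 1).flip([2, 3]), padding=k - 1)

print(torch.allclose(out_t, out_c, atol=1e-6))  # True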