comp-imaging / ProxImaL

A domain-specific language for image optimization.
MIT License

Extending Lin_Op #24

Closed grau4 closed 7 years ago

grau4 commented 7 years ago

Hello everyone,

I am currently extending the 'LinOp' class by creating a new class 'matrix_nonsquared_mult2'. This operator should perform the following operation:

C * i - h

where:

- "C" is the 442x442 operator (stored in "self.kernel"),
- "i" is the 442x19800 variable, and
- "h" is a constant of the matching 442x19800 shape (so the subtraction is defined).
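A quick NumPy shape check (placeholder zero arrays, just to illustrate the dimensions):

```python
import numpy as np

C = np.zeros((442, 442))      # the operator
i = np.zeros((442, 19800))    # the variable
h = np.zeros((442, 19800))    # the constant offset

out = C @ i - h
print(out.shape)              # (442, 19800): the output lives in the variable's space
```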
As can be noted, unlike "mul_elemwise", the dimensions of my objects are all different, so creating this new class by just modifying the forward and adjoint methods of "mul_elemwise" won't be sufficient; some reshaping must be done.

I have defined the constructor so that the shapes of both the operator "C" and the variable "i" match during execution of the "forward" method. At the point where the bug occurs, the "adjoint" method hasn't even been called yet.
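(For when the adjoint does get written: the adjoint of x -> C x is y -> C^H y, which can be checked numerically via the inner-product identity <C x, y> = <x, C^H y>. A minimal sketch with small stand-in sizes:)

```python
import numpy as np

C = np.random.randn(5, 5)     # stand-in for the 442x442 operator
x = np.random.randn(5, 8)     # stand-in for the 442x19800 variable
y = np.random.randn(5, 8)     # element of the output space

lhs = np.vdot(C @ x, y)           # <C x, y>
rhs = np.vdot(x, C.conj().T @ y)  # <x, C^H y>
assert np.isclose(lhs, rhs)
```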

The bug occurs during the execution of the "get_diag" method within my class, during the "partition/split" phase. Specifically, in the line:

`self_diag = np.reshape(self.kernel, self.size)`

The reason for the bug is that "self.kernel" contains the 442x442 operator, but "self.size" contains the size of the 442x19800 variable. In the cases of the "conv" and "mul_elemwise" modules such a conflict never arises: in the first case the kernel is reshaped within the operator's constructor to match the variable size and the operator then acts on the input by element-wise multiplication; in the second case, both the kernel and the input variable must have the same dimensions from the beginning.
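A minimal reproduction of the mismatch (shapes as above):

```python
import numpy as np

kernel = np.zeros((442, 442))   # self.kernel: the operator C
size = (442, 19800)             # self.size: the shape of the variable i

# reshape must preserve the total number of elements, and
# 442*442 != 442*19800, so this raises a ValueError
np.reshape(kernel, size)
```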

At the moment, I can't figure out how to make these match. I have tried storing different shapes within the constructor of my class, but then the program immediately breaks during the "absorption" stage of the "forward" method execution, since the shapes do not match during the matrix-matrix multiplication. So, where should I do the reshaping, and what are the correct shapes/sizes to store? I think I don't really understand what each shape or size is for.

I would appreciate any suggestions or ideas. I attach the code below. Thank you in advance. Javier

```python
import numpy as np

from proximal.lin_ops.lin_op import LinOp  # adjust to your local import path


class matrix_nonsquared_mult2(LinOp):

    def __init__(self, kernel, arg, implem=None):
        # assert arg.shape == kernel.shape  # Now we must admit different dimensions
        self.kernel = kernel
        self.forward_kernel = kernel
        shape = arg.shape

        super(matrix_nonsquared_mult2, self).__init__([arg], shape, implem)

    def init_kernel(self):
        print('Kernel Initialization: matrix_nonsquared_mult2')

    def forward(self, inputs, outputs):
        """The forward operator.

        Reads from inputs and writes to outputs.
        """
        np.copyto(outputs[0], np.dot(self.kernel, inputs[0]))

    def adjoint(self, inputs, outputs):
        """The adjoint operator.

        Reads from inputs and writes to outputs.
        """
        print('To be implemented...')
        self.adjoint_kernel = self.forward_kernel.conj()

    def is_diag(self, freq=False):
        return not freq and self.input_nodes[0].is_diag(freq)

    def get_diag(self, freq=False):
        """Returns the diagonal representation (A^TA)^(1/2).

        Parameters
        ----------
        freq : bool
            Is the diagonal representation in the frequency domain?

        Returns
        -------
        dict of variable to ndarray
            The diagonal operator acting on each variable.
        """
        assert not freq

        var_diags = self.input_nodes[0].get_diag(freq)
        # This is the failing line: self.kernel has 442*442 elements,
        # but self.size is the 442x19800 variable shape.
        self_diag = np.reshape(self.kernel, self.size)
        for var in var_diags.keys():
            var_diags[var] = var_diags[var] * self_diag
        return var_diags

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs.
        """
        return np.max(np.abs(self.kernel)) * input_mags[0]
```
SteveDiamond commented 7 years ago

Delete the get_diag and is_diag functions. They're not what you want here. Also the norm_bound is wrong, so delete that as well.
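For anyone finding this later, a minimal sketch of the class after following that advice (is_diag, get_diag, and norm_bound deleted so they no longer interfere, and the adjoint implemented as multiplication by C^H; the generalized output shape and the import path are my assumptions, not confirmed in this thread):

```python
import numpy as np

from proximal.lin_ops.lin_op import LinOp  # import path assumed


class matrix_nonsquared_mult2(LinOp):
    """y = C @ x for a (possibly non-square) matrix C."""

    def __init__(self, kernel, arg, implem=None):
        self.kernel = kernel
        # Output shape of C @ x: (rows of C) x (columns of x).
        shape = (kernel.shape[0], arg.shape[1])
        super(matrix_nonsquared_mult2, self).__init__([arg], shape, implem)

    def forward(self, inputs, outputs):
        """Writes C @ x into outputs[0]."""
        np.copyto(outputs[0], np.dot(self.kernel, inputs[0]))

    def adjoint(self, inputs, outputs):
        """Writes C^H @ y into outputs[0]."""
        np.copyto(outputs[0], np.dot(self.kernel.conj().T, inputs[0]))
```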

grau4 commented 7 years ago

Thank you very much, Steve. Bug solved.