KlugerLab / SpectralNet

Deep network that performs spectral clustering
MIT License

Orthonorm layer weights #14

Open killandy opened 5 years ago

killandy commented 5 years ago

Hi, I'm having trouble understanding the orthogonalization training step. I noticed that the orthogonality layer is actually a Keras Lambda layer that only performs a simple operation on its inputs. Moreover, the Lambda layer has no trainable weights, which is not consistent with the paper: "In each orthogonalization step we use the QR decomposition to tune the weights of the last layer."

I have read the code carefully, and I still can't figure out how the orthogonality layer is trained to produce orthogonal output. I'd appreciate it if you could explain it.

lihenryhfl commented 5 years ago

Hi,

Thanks for looking at our code. Please note that the weights are tuned by QR decomposition, which means that they are not 'trainable weights' in the traditional sense, since finding them does not involve gradient descent. Instead, they are simply recomputed at each orthogonalization step via a Cholesky decomposition, which for a full-rank input yields the same triangular factor as the QR decomposition (up to sign).

In particular, when the orthogonality layer is applied, we call Orthonorm() (which creates the Lambda layer you are talking about):

def Orthonorm(x, name=None):
    '''
    Builds keras layer that handles orthogonalization of x

    x:      an n x d input matrix
    name:   name of the keras layer

    returns:    a keras layer instance. during evaluation, the instance returns an n x d orthogonal matrix
                if x is full rank and not singular
    '''
    # get dimensionality of x
    d = x.get_shape().as_list()[-1]
    # compute orthogonalizing matrix
    ortho_weights = orthonorm_op(x)
    # create variable that holds this matrix
    ortho_weights_store = K.variable(np.zeros((d,d)))
    # create op that saves matrix into variable
    ortho_weights_update = tf.assign(ortho_weights_store, ortho_weights, name='ortho_weights_update')
    # switch between stored and calculated weights based on training or validation
    l = Lambda(lambda x: K.in_train_phase(K.dot(x, ortho_weights), K.dot(x, ortho_weights_store)), name=name)

    l.add_update(ortho_weights_update)
    return l

Note that the Lambda layer takes the layer's inputs and applies either ortho_weights or ortho_weights_store, depending on the training phase. Both of these are matrices calculated by orthonorm_op, which is the QR decomposition that is mentioned in the paper.
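To make the relationship between the Cholesky and QR views concrete, here is a minimal NumPy sketch. It assumes that orthonorm_op works along the lines described above (Cholesky factorization of x^T x, then an inverse); the function name `orthonorm_weights`, the `eps` regularizer, and the sqrt(n) scaling are illustrative assumptions, not a copy of the repository's code.

```python
import numpy as np

def orthonorm_weights(x, eps=1e-7):
    """Return a d x d matrix W such that (x @ W).T @ (x @ W) / n is close to I.

    Hypothetical helper for illustration only.
    """
    n, d = x.shape
    # Gram matrix, slightly regularized so the factorization is well defined
    gram = x.T @ x + eps * np.eye(d)
    # gram = chol @ chol.T with chol lower triangular
    chol = np.linalg.cholesky(gram)
    # Multiplying x by inv(chol).T orthonormalizes its columns; sqrt(n) rescales them
    return np.linalg.inv(chol).T * np.sqrt(n)

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4))

w = orthonorm_weights(x)
y = x @ w
gram_y = y.T @ y / x.shape[0]   # close to the 4 x 4 identity

# Compare with an explicit QR decomposition: the columns agree up to sign and scale
q, r = np.linalg.qr(x)
matches_qr = np.allclose(np.abs(y), np.sqrt(x.shape[0]) * np.abs(q), atol=1e-5)
```

The weights are a deterministic function of the batch `x`: no gradient step is involved in finding them, which is why they do not show up as trainable weights.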

killandy commented 5 years ago

Does that mean the matrix is recomputed (not updated) at every orthogonalization step, without using any previous information (such as a gradient)? If so, how do we make sure we end up with the final matrix we want?

lihenryhfl commented 5 years ago

Note that the orthogonalization layer orthogonalizes the input, and nothing else. Unlike fully connected, recurrent, convolutional, etc., layers, which are functions that depend both on the input AND a set of parameters (i.e. weights and biases), the orthogonalization layer is one that depends only on the input (i.e., given a fixed input, it performs a fixed operation -- namely the orthogonalization of the fixed input). So your question of finding the 'correct' final matrix is the wrong way to think about this layer. (Somewhat akin to asking about how to find the 'correct' ReLU function; there is simply no tunable parameter in the fixed operation.)
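As a rough illustration of a layer that depends only on its input, the sketch below mimics the behaviour described in this thread with plain NumPy. The class `OrthonormSketch` is hypothetical and simplified: it folds the "store the weights" update into the training call, which is an assumption made for brevity rather than a description of how the Keras update op is scheduled.

```python
import numpy as np

class OrthonormSketch:
    """Hypothetical stand-in for the Lambda layer; not the repository's code."""

    def __init__(self, d, eps=1e-7):
        self.eps = eps
        # Plays the role of ortho_weights_store (initialized to zeros)
        self.stored = np.zeros((d, d))

    def _weights(self, x):
        # Orthogonalizing matrix computed from the batch alone
        n, d = x.shape
        chol = np.linalg.cholesky(x.T @ x + self.eps * np.eye(d))
        return np.linalg.inv(chol).T * np.sqrt(n)

    def __call__(self, x, training):
        if training:
            # Recomputed from the current batch; nothing is learned by gradient descent
            self.stored = self._weights(x)
        # Outside training, the last stored matrix is reused
        return x @ self.stored

rng = np.random.default_rng(0)
layer = OrthonormSketch(d=3)

batch = rng.normal(size=(128, 3))
out_a = layer(batch, training=True)
out_b = layer(batch, training=True)
same_output = np.allclose(out_a, out_b)   # fixed input, fixed operation

other = rng.normal(size=(64, 3))
out_eval = layer(other, training=False)   # uses the matrix stored from `batch`
```

In this sketch, the training output is orthonormal for the batch it was computed from, while the evaluation output on a different batch is only approximately so, because it reuses the stored matrix.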

Since the orthogonalization layer is differentiable, the learning signal passes through the orthogonalization layer to the previous layers that do have tunable parameters. These parameters are trained such that the output of the orthogonalization layer is indeed the desired output -- that is, it approximates the solution of the constrained optimization task.

block98k commented 2 years ago

Hello. I have a question about the alternating training: why do you divide the training into two steps? Why not do both in the same mini-batch: get y_tilde, compute L, tune the weights of the last layer, and finally get the output y? I think this would guarantee orthogonalization for any input. Is there any problem with doing so? Thanks.