tensorflow / kfac

An implementation of KFAC for TensorFlow
Apache License 2.0

ValueError: Found the following errors with variable registration #31

Closed · StanleyGan closed this issue 4 years ago

StanleyGan commented 4 years ago

An example of the error is:

Variable <tf.Variable 'layer1/W:0' shape=(32,5) dtype=tf.float32_ref> registered with wrong number of uses (1 registrations vs 4 uses)

My loss is a custom loss that involves taking the gradient and the Hessian of the output w.r.t. the input. For example:

g1 = tf.gradients(output, x_placeholder)[0]
g2 = tf.gradients(g1, x_placeholder)[0]
loss = tf.reduce_sum(g1) + tf.reduce_sum(g2)

I then registered this loss with layer_collection.register_squared_error_loss(loss). I dug into the source of the error and found that the consumers of the variable in the function variable_uses() in kfac/ops/utils.py are the following: {<tf.Operation 'MatMul' type=BatchMatMulV2>, <tf.Operation 'gradients_1/gradients/MatMul_grad/MatMul_grad/MatMul' type=BatchMatMulV2>, <tf.Operation 'gradients/MatMul_grad/MatMul' type=BatchMatMulV2>, <tf.Operation 'gradients_1/MatMul_grad/MatMul' type=BatchMatMulV2>}, which explains the 4 uses. When I remove g2 and use loss = tf.reduce_sum(g1) instead, I get 2 uses.
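For concreteness, here is a minimal sketch of how this happens (the single tanh layer, names, and shapes below are made up for illustration): each tf.gradients call adds backprop MatMul ops that also read the weight, so the number of graph-level uses grows even though the layer is registered only once.

import tensorflow as tf

x_placeholder = tf.placeholder(tf.float32, shape=(None, 32))
W = tf.get_variable('layer1/W', shape=(32, 5))
output = tf.tanh(tf.matmul(x_placeholder, W))    # forward MatMul: 1st consumer of W

g1 = tf.gradients(output, x_placeholder)[0]      # backprop adds a MatMul reading W: 2nd consumer
g2 = tf.gradients(g1, x_placeholder)[0]          # second-order backprop adds further MatMuls reading W
loss = tf.reduce_sum(g1) + tf.reduce_sum(g2)     # (the 3rd and 4th consumers in the error above)

# variable_uses() in kfac/ops/utils.py counts these consumer ops, so it reports
# 4 uses for W while the layer registration only accounts for 1.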

So my questions are: 1) How do I make sure I register variables with the correct number of uses, especially when the loss involves gradients? 2) Is there a tutorial or example of using KFAC with a custom loss?

james-martens commented 4 years ago

I don't really understand the definition of your loss, but I'm guessing it won't be supported by K-FAC. Not just in the code but at the algorithm level. You could potentially just register the loss as something else that is supported and hope for the best, but you will be in uncharted territory.

Going back to your definition, you seem to be defining the loss as the gradient of something. Note that the loss needs to be a scalar. Is it a scalar here?

StanleyGan commented 4 years ago

The loss that I am trying to minimize is too complicated to define here, so I used the simple example above; the tricky part is generally that it involves taking gradients. Yes, it will be a scalar after I apply tf.reduce_sum(). I have edited my question.

james-martens commented 4 years ago

Well, if your loss involves gradients that depend on the network, it seems to me that this structure is incompatible with the assumptions that K-FAC makes. In particular, your objective can no longer simply be written as L( f(x, theta), y ), where f is a neural network.

StanleyGan commented 4 years ago

Thanks for your reply! I am actually trying to reproduce this paper, where KFAC is used to optimize the local energy for the many-electron Schrödinger equation: https://arxiv.org/pdf/1909.02487.pdf. If you look at page 4 of the paper, the local energy involves the square of the gradient of the output w.r.t. the electron coordinates and the Laplacian of the output w.r.t. the coordinates. Based on the KFAC tutorial, the loss needs to be registered, so I registered the local energy as the loss.
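To make it concrete, the kinetic part of what I am registering looks roughly like the following sketch (the stand-in network, shapes, and names are only for illustration; the potential-energy terms and the sampling are omitted):

import tensorflow as tf

# Stand-in for the wavefunction network: any differentiable map from electron
# coordinates to a scalar log|psi| per configuration works for this sketch.
coords = tf.placeholder(tf.float32, shape=(None, 12))        # e.g. 4 electrons x 3 coordinates
hidden = tf.layers.dense(coords, 64, activation=tf.tanh)
log_psi = tf.squeeze(tf.layers.dense(hidden, 1), axis=-1)    # shape (None,)

grad = tf.gradients(log_psi, coords)[0]                      # gradient of log|psi| w.r.t. the coordinates
grad_sq = tf.reduce_sum(tf.square(grad), axis=-1)            # squared norm of that gradient

laplacian = 0.
for i in range(coords.shape.as_list()[-1]):                  # Laplacian as a sum of second derivatives
    laplacian += tf.gradients(grad[:, i], coords)[0][:, i]

kinetic = -0.5 * (laplacian + grad_sq)                       # kinetic part of the local energy per configuration
loss = tf.reduce_mean(kinetic)                               # plus potential terms in the full objective; this is what I register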

james-martens commented 4 years ago

You may want to read that paper more carefully and possibly contact the authors. Perhaps they even have code that could be shared. I don't think K-FAC is being applied straightforwardly to the loss itself in their case. Rather, that paper is taking more of a "natural gradient" view where the curvature matrix is no longer an approximation of the Hessian of the loss, but rather is the Fisher associated with the distribution generated by the model. (In some scenarios these things are equivalent, but not so in their case.)