Open jbp70 opened 3 years ago
Hi,
I have the same issue with similar code:
```python
import tensorflow as tf

class WeightStandardization(tf.keras.constraints.Constraint):
    def __call__(self, w):
        # Standardize each output filter: zero mean, unit std over axes [0, 1, 2].
        mean = tf.math.reduce_mean(w, axis=[0, 1, 2], keepdims=True)
        std = tf.math.reduce_std(w, axis=[0, 1, 2], keepdims=True)
        return (w - mean) / tf.maximum(std, tf.keras.backend.epsilon())
```
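For context, a hypothetical usage sketch (not from the original report): a constraint like this would normally be attached through the layer's `kernel_constraint` argument, and Keras then applies it to the kernel after each weight update. The layer parameters below are placeholders.

```python
import tensorflow as tf

# Hypothetical usage: the constraint is applied to the conv kernel after each update.
conv = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=3,
    padding="same",
    kernel_constraint=WeightStandardization(),
)
```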
@jbp70, any progress, please? @joe-siyuan-qiao, do you have any idea?
Hi @markub3327. I had tried using this code a long while ago. I was unsuccessful in figuring out how to implement this properly and ended up simply moving on. If you figure it out, do let me know!
@jbp70 It must be done before the gradients are computed. I believe the standardization has to happen during the forward pass: the kernel should be substituted with its normalized version, and that version used for the prediction, so the gradients flow through the standardized weights; only then are the gradients applied. Keras constraints, by contrast, run only after the gradients have been calculated and the update applied.
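For illustration only (not the paper's reference code), a minimal sketch of that idea: a hypothetical `WSConv2D` layer that standardizes its kernel inside `call()`, so both the forward pass and the gradients go through the standardized weights. The layer name, initializer, stride, and padding are placeholders.

```python
import tensorflow as tf

class WSConv2D(tf.keras.layers.Layer):
    """Hypothetical conv layer that standardizes its kernel in the forward pass."""

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(self.kernel_size, self.kernel_size,
                   int(input_shape[-1]), self.filters),
            initializer="he_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Standardize over all axes except the output-channel axis, then convolve
        # with the standardized kernel so gradients flow through it.
        mean, variance = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
        kernel = (self.kernel - mean) / tf.maximum(
            tf.sqrt(variance), tf.keras.backend.epsilon()
        )
        return tf.nn.conv2d(inputs, kernel, strides=1, padding="SAME")
```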
@jbp70 What about this implementation?
```python
import tensorflow as tf

class WeightStandardization(tf.keras.constraints.Constraint):
    def __call__(self, w):
        # Same idea via tf.nn.moments: mean and variance over axes [0, 1, 2].
        mean, variance = tf.nn.moments(w, axes=[0, 1, 2], keepdims=True)
        std = tf.sqrt(variance)
        epsilon = tf.keras.backend.epsilon()
        return (w - mean) / tf.maximum(std, epsilon)
```
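For what it's worth, a quick sanity check of this variant (the kernel shape below is arbitrary): after the constraint runs, every output filter should have mean ≈ 0 and standard deviation ≈ 1 over axes [0, 1, 2].

```python
import tensorflow as tf

# Arbitrary 3x3 kernel with 3 input and 16 output channels.
w = tf.random.normal([3, 3, 3, 16])
w_std = WeightStandardization()(w)

print(tf.reduce_max(tf.abs(tf.reduce_mean(w_std, axis=[0, 1, 2]))))            # ~0
print(tf.reduce_max(tf.abs(tf.math.reduce_std(w_std, axis=[0, 1, 2]) - 1.0)))  # ~0
```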
Hi! I am currently trying to implement your code in TensorFlow 2.2, but I am running into errors. I get a strange error when trying to place the standardization directly before calling the convolution (as you said to do in post #11). For reference, the error is: "TypeError: An op outside of the function building code is being passed a "Graph" tensor."
I wasn't able to figure out how to correct this, so I decided to create a custom kernel constraint and use that for the weight standardization.
```python
class CustomWeightStandardization(tf.keras.constraints.Constraint):
    ...
```
As far as I understand, this should be equivalent to how you have implemented weight standardization. I am training an image segmentation model using a 3D U-Net architecture trained from scratch on a medical imaging dataset I have. Unfortunately, turning on this kernel constraint makes the model perform worse than when I train without it. Do you have any ideas on how to fix this?