joe-siyuan-qiao / WeightStandardization

Standardizing weights to accelerate micro-batch training

Implementation question in Tensorflow #26

Open jbp70 opened 3 years ago

jbp70 commented 3 years ago

Hi! I am currently trying to implement your code in TensorFlow 2.2, but am running into errors. I get a strange error when I place the standardization directly before calling the convolution (as you suggested in #11). For reference, the error is: "TypeError: An op outside of the function building code is being passed a "Graph" tensor."

I wasn't able to figure out how to correct this, so I decided to create a custom kernel constraint and use that for the weight standardization.

import tensorflow as tf

class CustomWeightStandardization(tf.keras.constraints.Constraint):

    def __init__(self, axis=(0, 1, 2, 3)):
        # Axes to reduce over: the spatial and input-channel dimensions of a
        # 3D conv kernel, leaving one statistic per output channel.
        self.axis = axis

    def __call__(self, w):
        mean = tf.math.reduce_mean(w, axis=self.axis, keepdims=True)
        std = tf.math.sqrt(tf.math.reduce_variance(w, axis=self.axis, keepdims=True) + tf.keras.backend.epsilon()) + 0.00001
        return (w - mean) / std

    def get_config(self):
        return {'axis': self.axis}

As far as I understand, this should be equivalent to how you have implemented weight standardization. I am training an image segmentation model using a 3D U-Net architecture trained from scratch on a medical imaging dataset I have. Unfortunately, turning on this kernel constraint makes the model perform worse than when I train without it. Do you have any ideas on how to fix this?
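
For reference, a constraint like this gets attached to each convolution roughly as follows (the filter count below is just a placeholder, not my actual network):

import tensorflow as tf

conv = tf.keras.layers.Conv3D(
    filters=32,  # placeholder value
    kernel_size=3,
    padding="same",
    kernel_constraint=CustomWeightStandardization(axis=(0, 1, 2, 3)),
)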

markub3327 commented 1 year ago

Hi,

I have the same issue with similar code:

class WeightStandardization(tf.keras.constraints.Constraint):
    def __call__(self, w):
        mean = tf.math.reduce_mean(w, axis=[0, 1, 2], keepdims=True)
        std = tf.math.reduce_std(w, axis=[0, 1, 2], keepdims=True)
        return (w - mean) / tf.maximum(std, tf.keras.backend.epsilon())

@jbp70, any progress, please? @joe-siyuan-qiao, do you have any idea?

jbp70 commented 1 year ago

Hi @markub3327. I tried this code a long while ago, but I was never able to get it working properly and ended up moving on. If you figure it out, do let me know!

markub3327 commented 1 year ago

@jbp70 The standardization must happen before the gradients are computed: during the forward pass the kernel should be replaced by its standardized version, and the gradients should then flow through that operation. A Keras constraint does the opposite, it is applied to the weights only after the gradient update, so it doesn't reproduce Weight Standardization.
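
Roughly, what I mean is something like this (just a sketch, not the implementation from this repo; it relies on the convolution_op hook that newer Keras Conv layers expose, so it needs a more recent TF than 2.2, and the name WSConv3D is made up):

import tensorflow as tf

class WSConv3D(tf.keras.layers.Conv3D):
    # Standardize the kernel in the forward pass, so gradients are computed
    # through the standardization instead of the constraint being applied
    # only after the weight update.
    def convolution_op(self, inputs, kernel):
        mean, var = tf.nn.moments(kernel, axes=[0, 1, 2, 3], keepdims=True)
        standardized = (kernel - mean) / tf.sqrt(var + 1e-5)
        return super().convolution_op(inputs, standardized)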

markub3327 commented 1 year ago

@jbp70 What about this implementation?

class WeightStandardization(tf.keras.constraints.Constraint):
    def __call__(self, w):
        mean, variance = tf.nn.moments(w, axes=[0, 1, 2], keepdims=True)
        std = tf.sqrt(variance)
        epsilon = tf.keras.backend.epsilon()
        return (w - mean) / tf.maximum(std, epsilon)