peterjliu opened 3 years ago
When referring to a Module's internal variables in `__call__`, one needs to use `self.x.value` instead of simply `self.x`. It'd be nice to enable this syntactic sugar to improve the readability of complex math expressions. For example, `tf.Module` allows this:
```python
import tensorflow as tf


class Dense(tf.Module):
    def __init__(self, in_features, out_features, name=None):
        super(Dense, self).__init__(name=name)
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        # The tf.Variable attributes are used directly in the math; no `.value` needed.
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
```
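One way to get this sugar for Python operators is to have the variable wrapper forward arithmetic dunder methods to its stored `.value`. The sketch below is a minimal, plain-Python illustration under that assumption; `Var` and `Affine` are hypothetical names, not the library's actual classes, and scalars stand in for arrays so it runs without JAX.

```python
import operator


def _unwrap(x):
    # Allow Var-op-Var expressions by unwrapping either operand.
    return x.value if isinstance(x, Var) else x


def _binary(op):
    # Build a forward (self op other) and a reflected (other op self) method.
    def forward(self, other):
        return op(self.value, _unwrap(other))

    def reflected(self, other):
        return op(_unwrap(other), self.value)

    return forward, reflected


class Var:
    """Hypothetical variable wrapper that forwards arithmetic to `.value`."""

    def __init__(self, value):
        self.value = value

    __add__, __radd__ = _binary(operator.add)
    __sub__, __rsub__ = _binary(operator.sub)
    __mul__, __rmul__ = _binary(operator.mul)
    __truediv__, __rtruediv__ = _binary(operator.truediv)
    __matmul__, __rmatmul__ = _binary(operator.matmul)

    def __neg__(self):
        return -self.value


class Affine:
    """Hypothetical scalar module computing w * x + b without `.value`."""

    def __init__(self, w, b):
        self.w = Var(w)
        self.b = Var(b)

    def __call__(self, x):
        # Var.__mul__ handles `self.w * x`; Var.__radd__ handles `... + self.b`.
        return self.w * x + self.b


print(Affine(2.0, 1.0)(3.0))
```

Note that operator forwarding only covers expressions written with Python operators. Library functions such as `jnp.matmul(self.w, x)` do not call these dunder methods, so a wrapper like this is not accepted there unless the array library itself learns to recognize it.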
I will implement it once the JAX team makes this change: https://github.com/google/jax/pull/4725