google / objax

Apache License 2.0

Remove need for .value when referring to internal param values. #111

Open peterjliu opened 3 years ago

peterjliu commented 3 years ago

When referring to a Module's internal variables in __call__, one needs to write self.x.value instead of simply self.x. It'd be nice to support this syntactic sugar to improve the readability of complex math expressions. For example, tf.Module allows this:

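 import tensorflow as tf

 class Dense(tf.Module):
   def __init__(self, in_features, out_features, name=None):
     super(Dense, self).__init__(name=name)
     # Both variables are tracked by tf.Module and used directly in __call__.
     self.w = tf.Variable(
       tf.random.normal([in_features, out_features]), name='w')
     self.b = tf.Variable(tf.zeros([out_features]), name='b')

   def __call__(self, x):
     # tf.Variable converts to a tensor automatically, so no .value is needed.
     y = tf.matmul(x, self.w) + self.b
     return tf.nn.relu(y)
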
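To illustrate the kind of sugar being requested, here is a minimal pure-Python sketch, assuming a hypothetical Var wrapper that stores its payload in .value and forwards infix arithmetic to it. Var and this Dense are illustrative names only; this is not objax's actual API or a proposed implementation. It uses plain floats so it runs without JAX.

```python
# Hypothetical sketch, NOT objax's API: a variable wrapper that forwards
# infix arithmetic to its stored .value, so module code can write
# x * self.w instead of x * self.w.value.
import operator


class Var:
    """Hypothetical variable wrapper holding a mutable .value."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"Var({self.value!r})"


def _forward(op):
    # Build `self <op> other`, unwrapping `other` if it is also a Var.
    def method(self, other):
        other = other.value if isinstance(other, Var) else other
        return op(self.value, other)
    return method


def _reflect(op):
    # Build the reflected `other <op> self`, used when `other` is not a Var
    # and its own operator returned NotImplemented.
    def method(self, other):
        return op(other, self.value)
    return method


for _name, _op in [("add", operator.add), ("sub", operator.sub),
                   ("mul", operator.mul), ("truediv", operator.truediv),
                   ("matmul", operator.matmul)]:
    setattr(Var, f"__{_name}__", _forward(_op))
    setattr(Var, f"__r{_name}__", _reflect(_op))


class Dense:
    """Hypothetical scalar dense layer written without .value."""

    def __init__(self, w, b):
        self.w = Var(w)
        self.b = Var(b)

    def __call__(self, x):
        y = x * self.w + self.b  # reads like plain math, no .value
        return max(y, 0.0)  # scalar ReLU


layer = Dense(2.0, -1.0)
print(layer(3.0))   # 3.0 * 2.0 - 1.0
print(layer(0.25))  # negative pre-activation is clipped to 0.0
```

Operator overloading like this only covers infix expressions. Passing the wrapper to library functions (for example, jax.numpy calls) would still need separate support from the array library itself.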
AlexeyKurakin commented 3 years ago

I will implement it after the JAX team makes this change: https://github.com/google/jax/pull/4725