HUJI-Deep / simnets-tf

SimNets implementation in TensorFlow
MIT License

Dirichlet initializer API and usage #13

Closed: orsharir closed this issue 7 years ago

orsharir commented 7 years ago

The Dirichlet initializer should have a public, documented interface that accepts an alpha value and returns the corresponding private Dirichlet initializer function. It could look something like this:

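```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def dirichlet_init(alpha=1.0):
    """Return an initializer that draws log-Dirichlet(alpha) weights for a SimNets layer."""
    def _dirichlet_init(shape, dtype=None):
        if dtype is None:
            dtype = K.floatx()
        # Accept both dtype strings ('float32') and tf.DType objects.
        np_dtype = tf.as_dtype(dtype).as_numpy_dtype
        dims = tuple(int(d) for d in shape)  # works for tuples and TensorShape
        num_regions, num_instances, block_c, block_h, block_w = dims
        k = block_c * block_h * block_w
        # Given size s, np.random.dirichlet returns an array of shape s + (k,);
        # we then reshape it to the variable's shape.
        init_np = np.random.dirichlet([alpha] * k, size=(num_regions, num_instances))
        init_np = np.log(init_np).astype(np_dtype)
        init_np = init_np.reshape(dims)
        return tf.constant(init_np)
    return _dirichlet_init
```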

Note that the code above already incorporates the correction from #12.
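As a quick sanity check of what the initializer produces, here is a NumPy-only sketch that reproduces its math outside TensorFlow. The function name `log_dirichlet_sample` and its `seed` parameter are hypothetical (not part of simnets-tf), and it uses NumPy's `Generator.dirichlet` instead of the legacy `np.random.dirichlet`. It shows that exponentiating the weights and summing over the block axes gives 1 for every (region, instance) pair, i.e. each instance's weights are log-probabilities on the simplex.

```python
import numpy as np

def log_dirichlet_sample(shape, alpha=1.0, seed=None):
    """Hypothetical helper mirroring the initializer's math in plain NumPy.

    shape is (num_regions, num_instances, block_c, block_h, block_w); each
    (region, instance) pair gets one Dirichlet draw over k = c * h * w entries.
    """
    num_regions, num_instances, block_c, block_h, block_w = shape
    k = block_c * block_h * block_w
    rng = np.random.default_rng(seed)
    # Generator.dirichlet returns an array of shape size + (k,).
    samples = rng.dirichlet([alpha] * k, size=(num_regions, num_instances))
    # Take logs, then fold the trailing k axis back into (c, h, w).
    return np.log(samples).reshape(shape)

shape = (2, 3, 4, 2, 2)
w = log_dirichlet_sample(shape, alpha=1.0, seed=0)
print(w.shape)  # (2, 3, 4, 2, 2)
# exp(w) summed over the three block axes should be 1 for every (region, instance).
sums = np.exp(w).sum(axis=(2, 3, 4))
print(np.allclose(sums, 1.0))  # True
```

Because the sums are taken over axes 2 to 4, this also confirms that the reshape groups each draw's k entries into a single (block_c, block_h, block_w) block.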