keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Constant Initializer supports only Scalar values #837

Closed Mrutyunjay01 closed 1 year ago

Mrutyunjay01 commented 1 year ago

While trying to port the example End-to-end Masked Language Modeling with BERT to keras_core, I faced the following issue: the modified `layers.Embedding` doesn't support the `weights` argument used to add positional encodings for MLM. The issue is mentioned here and here.

As a workaround, I made the following changes to adapt the positional encoding as embeddings_initializer:

    position_embeddings = layers.Embedding(
        input_dim=config.MAX_LEN,
        output_dim=config.EMBED_DIM,
        embeddings_initializer=Constant(get_pos_encoding_matrix(config.MAX_LEN, config.EMBED_DIM)),
        name="position_embedding",
    )(tf.range(start=0, limit=config.MAX_LEN, delta=1))
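For context, `get_pos_encoding_matrix` in the original example builds the standard sinusoidal position encodings. A minimal NumPy sketch of that helper (my reconstruction, not necessarily the exact tutorial code):

```python
import numpy as np

def get_pos_encoding_matrix(max_len, d_emb):
    # Sinusoidal position encodings: even embedding dims use sin,
    # odd dims use cos; position 0 is all zeros.
    pos_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
            if pos != 0
            else np.zeros(d_emb)
            for pos in range(max_len)
        ]
    )
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # even dims
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # odd dims
    return pos_enc
```

The result is a `(max_len, d_emb)` float array, which is what gets passed to the `Constant` initializer above.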

But as shown below, the `Constant` initializer supports only scalar values.

https://github.com/keras-team/keras-core/blob/f0f8f1cb398954c050e4e3a46e1f062c8ed7c848/keras_core/initializers/constant_initializers.py#L32

Hence, to make it work, I changed the constructor to the following:

    def __init__(self, value=0.0):
        self.value = value.astype('float')

The MLM BERT pretraining is running as I raise this issue, with the mentioned modifications and other required adaptations for keras_core. I will open a PR once I reproduce the results. Any avid users of keras_core, kindly confirm this workaround; if it fits, I will include this fix in my PR for the End-to-end Masked Language Modeling with BERT port.

Thanks.

cc: @fchollet

fchollet commented 1 year ago

Thanks for the report. Yes, I think we can modify the `Constant` initializer. Instead of casting in the constructor, you can just remove the `float()` in the constructor and then use `ops.cast` in `call()`.
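A minimal sketch of that suggested change, with NumPy standing in for `keras_core.ops` so it runs standalone (the real fix would use `ops.cast` and the backend's array creation instead):

```python
import numpy as np

class Constant:
    """Sketch of the suggested fix: store `value` as-is in the constructor
    and cast at call time, so both scalars and arrays are accepted."""

    def __init__(self, value=0.0):
        # No float() cast here: `value` may be a scalar or an ndarray.
        self.value = value

    def __call__(self, shape, dtype="float32"):
        # Cast at call time (ops.cast in keras_core); scalars broadcast to
        # `shape`, while an array value must already match `shape`.
        return np.broadcast_to(np.asarray(self.value, dtype=dtype), shape)
```

With this, `Constant(get_pos_encoding_matrix(MAX_LEN, EMBED_DIM))` works as an `embeddings_initializer`, while scalar usage like `Constant(0.0)` is unchanged.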