huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Creating Flax VisualBert based on Flax Bert #12504

Closed gchhablani closed 3 years ago

gchhablani commented 3 years ago

I am using the VisualBert model and the FlaxBert model as references to create a VisualBert-like model in Flax (it will use ViT instead of Detectron, hence the name ViTBert). Here are the embeddings:

# (imports for the snippets in this issue)
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict

from transformers import BertConfig, FlaxBertPreTrainedModel, FlaxPreTrainedModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutputWithPooling
from transformers.models.bert.modeling_flax_bert import FlaxBertEncoder, FlaxBertPooler


class FlaxViTBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings, plus projected visual embeddings."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.visual_projection = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        )
        self.visual_position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.visual_token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        input_ids,
        token_type_ids,
        position_ids,
        visual_inputs_embeds,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
    ):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
        # Sum all embeddings
        word_embeddings = inputs_embeds + token_type_embeddings + position_embeds
        # Visual Embed
        visual_inputs_embeds = self.visual_projection(visual_inputs_embeds)
        visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids.astype("i4"))
        visual_position_embeds = self.visual_position_embeddings(visual_position_ids.astype("i4"))
        # Sum all visual embeddings
        visual_embeddings = visual_inputs_embeds + visual_token_type_embeddings + visual_position_embeds
        # Concat
        hidden_states = jnp.concatenate((word_embeddings, visual_embeddings), axis=1)
        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states

These embeddings work fine on their own: I generate parameters using a random key and then apply them.
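For example, the quick check I run looks roughly like this (the shapes and keys are arbitrary, just for smoke-testing):

config = BertConfig.from_pretrained("bert-base-uncased")
embeddings = FlaxViTBertEmbeddings(config)

batch_size, seq_len, vis_seq_len, vis_dim = 2, 8, 4, 768
input_ids = jnp.ones((batch_size, seq_len), dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(seq_len), (batch_size, seq_len))
visual_inputs_embeds = jax.random.normal(jax.random.PRNGKey(0), (batch_size, vis_seq_len, vis_dim))
visual_token_type_ids = jnp.ones((batch_size, vis_seq_len), dtype="i4")
visual_position_ids = jnp.broadcast_to(jnp.arange(vis_seq_len), (batch_size, vis_seq_len))

# init returns {"params": ...}, which can be passed straight back to apply
variables = embeddings.init(
    jax.random.PRNGKey(1),
    input_ids, token_type_ids, position_ids,
    visual_inputs_embeds, visual_token_type_ids, visual_position_ids,
)
hidden_states = embeddings.apply(
    variables,
    input_ids, token_type_ids, position_ids,
    visual_inputs_embeds, visual_token_type_ids, visual_position_ids,
)
print(hidden_states.shape)  # (batch_size, seq_len + vis_seq_len, hidden_size)

Then, I create the model like so: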

class FlaxViTBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxViTBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, visual_input_shape) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        visual_inputs_embeds = jax.random.normal(visual_input_shape)
        visual_attention_mask = jnp.ones(visual_input_shape[:-1])
        visual_token_type_ids = jnp.ones(visual_input_shape[:-1])
        visual_position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(visual_input_shape).shape[-2]), visual_input_shape[:-1])

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            visual_inputs_embeds,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        visual_inputs_embeds,
        visual_attention_mask,
        visual_token_type_ids, 
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, visual_inputs_embeds, visual_token_type_ids, visual_position_ids, deterministic=deterministic
        )
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class FlaxViTBertModel(FlaxBertPreTrainedModel):
    module_class = FlaxViTBertModule

When I try this:

flax_model = FlaxViTBertModel.from_pretrained('bert-base-uncased')

I get the following error:

TypeError: __call__() missing 4 required positional arguments: 'visual_inputs_embeds', 'visual_attention_mask', 'visual_token_type_ids', and 'visual_position_ids'

I believe the issue is that FlaxBertPreTrainedModel only takes input_shape, and it in turn calls FlaxPreTrainedModel's __init__(), which again only accepts input_shape.
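For reference, the inherited init_weights only builds the textual inputs before calling the module (paraphrased from the transformers version I am on, so treat it as approximate):

# FlaxBertPreTrainedModel.init_weights (approximate paraphrase)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    input_ids = jnp.zeros(input_shape, dtype="i4")
    token_type_ids = jnp.zeros_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
    attention_mask = jnp.ones_like(input_ids)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    # only the four textual inputs are passed, so my module's __call__
    # never receives the visual arguments, hence the TypeError above
    return self.module.init(
        rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
    )["params"]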

What would be an elegant way to deal with this? How do I have some weights initialized randomly and the rest loaded from the pre-trained checkpoint?
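Concretely, the kind of thing I am after is the following (a hand-wavy sketch; it assumes the FlaxViTBertModel constructor already random-initializes everything, which is exactly the part that currently fails):

from flax.core.frozen_dict import freeze, unfreeze
from transformers import FlaxBertModel

vitbert = FlaxViTBertModel(BertConfig.from_pretrained("bert-base-uncased"))  # fully random init
bert = FlaxBertModel.from_pretrained("bert-base-uncased")

params = unfreeze(vitbert.params)
bert_params = unfreeze(bert.params)

# copy the sub-trees shared with BERT; the new visual layers stay randomly initialized
for key in ("encoder", "pooler"):
    params[key] = bert_params[key]

vitbert.params = freeze(params)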

EDIT: I am aware there will be a shape mismatch in the model. I will fix it when it comes to that.

gchhablani commented 3 years ago

I tried another way (by modifying the pre-trained class):

class FlaxViTBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "vitbert"
    module_class: nn.Module = None

    def __init__(
        self, config: BertConfig, input_shape: Tuple = ((1, 1), (1, 1, 1)), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape)
        attention_mask = jnp.ones_like(input_ids)

        visual_input_shape = input_shape[1]
        visual_inputs_embeds = jax.random.normal(visual_input_shape)
        visual_attention_mask = jnp.ones(visual_input_shape[:-1])
        visual_token_type_ids = jnp.ones(visual_input_shape[:-1])
        visual_position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(visual_input_shape).shape[-2]), visual_input_shape[:-1])

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            visual_inputs_embeds,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        visual_inputs_embeds=None,
        visual_attention_mask=None,
        visual_token_type_ids=None, 
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_inputs_embeds.shape[:-1])

        if visual_position_ids is None:
            visual_position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(visual_inputs_embeds).shape[-2]), visual_inputs_embeds.shape[:-1])

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_inputs_embeds.shape[:-1])

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(visual_inputs_embeds, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

class FlaxViTBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxViTBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        visual_inputs_embeds,
        visual_attention_mask,
        visual_token_type_ids, 
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, visual_inputs_embeds, visual_token_type_ids, visual_position_ids, deterministic=deterministic
        )

        # the hidden states are now text + visual, so the encoder needs the concatenated mask
        combined_attention_mask = jnp.concatenate((attention_mask, visual_attention_mask), axis=1)

        outputs = self.encoder(
            hidden_states,
            combined_attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class FlaxViTBertModel(FlaxViTBertPreTrainedModel):
    module_class = FlaxViTBertModule

Now I get the following error when trying:

flax_model = FlaxViTBertModel.from_pretrained('bert-base-multilingual-uncased')
TypeError: _random_bits got invalid prng key.

Any idea why this happens?

gchhablani commented 3 years ago

Never mind, I forgot to pass the random key to the normal method. I will update here once I get the model working.
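For the record, the offending line in init_weights should be something along these lines:

rng, visual_rng = jax.random.split(rng)
visual_inputs_embeds = jax.random.normal(visual_rng, visual_input_shape)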

gchhablani commented 3 years ago

I have updated the code based on the Hybrid CLIP example, but I use FlaxViTModule inside FlaxViTBertEmbeddings, along the lines sketched below.
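(This is just the pattern, not the exact notebook code; the vit_config field, the pixel_values input, and the helper class name are my shorthand here.)

from transformers import ViTConfig
from transformers.models.vit.modeling_flax_vit import FlaxViTModule

class FlaxViTBertVisualEmbeddings(nn.Module):  # illustrative helper, not the exact notebook class
    config: BertConfig
    vit_config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.vit = FlaxViTModule(self.vit_config, dtype=self.dtype)
        self.visual_projection = nn.Dense(self.config.hidden_size, dtype=self.dtype)

    def __call__(self, pixel_values, deterministic: bool = True):
        # ViT patch representations, projected to BERT's hidden size
        vit_outputs = self.vit(pixel_values, deterministic=deterministic)
        return self.visual_projection(vit_outputs[0])

Now I get the following error: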

AssertionError: A state dict must only have string keys.

Notebook: https://colab.research.google.com/drive/1mNzt4NRBpibJ_7U73Rj3sDPAkcMRPXTd?usp=sharing