I tried another way (by modifying the pre-trained class):
class FlaxViTBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "vitbert"
    module_class: nn.Module = None

    def __init__(
        self, config: BertConfig, input_shape: Tuple = ((1, 1), (1, 1, 1)), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape)
        attention_mask = jnp.ones_like(input_ids)

        visual_input_shape = input_shape[1]
        visual_inputs_embeds = jax.random.normal(visual_input_shape),
        visual_attention_mask = jnp.ones(visual_input_shape[:-1])
        visual_token_type_ids = jnp.ones(visual_input_shape[:-1])
        visual_position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(visual_input_shape).shape[-2]), visual_input_shape[:-1])

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            visual_inputs_embeds,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        visual_inputs_embeds=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_inputs_embeds.shape[:-1])
        if visual_position_ids is None:
            visual_position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(visual_inputs_embeds).shape[-2]), visual_inputs_embeds.shape[:-1])
        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_inputs_embeds.shape[:-1])

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(visual_inputs_embeds, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
class FlaxViTBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxViTBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        visual_inputs_embeds,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, visual_inputs_embeds, visual_token_type_ids, visual_position_ids, deterministic=deterministic
        )
        combined_attention_mask = jnp.concatenate((attention_mask, visual_attention_mask), axis=1)
        outputs = self.encoder(
            hidden_states,
            combined_attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxViTBertModel(FlaxViTBertPreTrainedModel):
    module_class = FlaxViTBertModule
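With this, input_shape is meant to be a pair of (textual_shape, visual_shape), so the intended instantiation is along these lines (the shapes here are just placeholders):

config = BertConfig.from_pretrained('bert-base-multilingual-uncased')
model = FlaxViTBertModel(config, input_shape=((1, 8), (1, 197, 768)))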
Now I get the following issue when trying:
flax_model = FlaxViTBertModel.from_pretrained('bert-base-multilingual-uncased')
TypeError: _random_bits got invalid prng key.
Any idea why this happens?
Nevermind, I forgot to pass the random key to the normal method. I will update here when I am successful with the model.
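For reference, the corrected lines in init_weights are roughly the following (splitting one extra key for the dummy visual embeddings is just one way to do it):

params_rng, dropout_rng, visual_rng = jax.random.split(rng, 3)
rngs = {"params": params_rng, "dropout": dropout_rng}
visual_inputs_embeds = jax.random.normal(visual_rng, visual_input_shape)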
I have updated the code based on the Hybrid CLIP example, but now I use FlaxViTModule inside FlaxViTBertEmbeddings, roughly as sketched below.
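The embedding module looks roughly like this (a sketch only: the vit_config attribute on the config, the projection layer, and the simplified argument list are placeholders; the exact code is in the notebook linked below):

class FlaxViTBertEmbeddings(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # textual embeddings come from FlaxBertEmbeddings
        self.text_embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
        # visual embeddings come from a FlaxViTModule (config.vit_config is an assumption of this sketch)
        self.vit_module = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
        # project the ViT hidden states to the BERT hidden size before concatenating
        self.visual_projection = nn.Dense(self.config.hidden_size, dtype=self.dtype)

    def __call__(self, input_ids, token_type_ids, position_ids, pixel_values, deterministic: bool = True):
        text_embeds = self.text_embeddings(
            input_ids, token_type_ids, position_ids, attention_mask=None, deterministic=deterministic
        )
        patch_embeds = self.vit_module(pixel_values, deterministic=deterministic)[0]
        visual_embeds = self.visual_projection(patch_embeds)
        # concatenate textual and visual embeddings along the sequence dimension
        return jnp.concatenate((text_embeds, visual_embeds), axis=1)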
Now I get the following error:
AssertionError: A state dict must only have string keys.
Notebook: https://colab.research.google.com/drive/1mNzt4NRBpibJ_7U73Rj3sDPAkcMRPXTd?usp=sharing
I am using the VisualBert model and the FlaxBert model to create a model similar to VisualBert in Flax (which will use ViT instead of Detectron, hence the name). Here are the embeddings:
These embeddings work fine. I generate parameters using a random key and then apply those parameters. Then, I create the model like so:
When I try this:
I get the following error:
I believe the issue is because FlaxBertPreTrainedModel only takes in input_shape. But FlaxBertPreTrainedModel in turn calls FlaxPreTrainedModel's __init__(), which again only has input_shape. What would be an elegant way to deal with this? How do I create some weights randomly while initializing the rest from the pre-trained checkpoint?
EDIT: I am aware there will be a shape mismatch in the model. I will fix that when I get to it.