coastalcph / lex-glue

LexGLUE: A Benchmark Dataset for Legal Language Understanding in English

Loading Hierarchical models #16

Closed · Just-Strato closed this 2 years ago

Just-Strato commented 2 years ago

Hi, I used the scripts and everything worked fine: I was able to train the models without any trouble, and the test results reported after training are also coherent.

But the issue (shown at the end of this message) occurred when I tried to load the model in order to predict on other samples. The model cannot be loaded because the layer names it expects differ from the layer names stored in the saved file. As the error message shows, some layer names in the saved file contain a double occurrence of "encoder" (e.g. bert.encoder.encoder.layer...), and the model does not use those layer names when loading.

This problem happens with the ECtHR (A & B) and SCOTUS tasks (maybe others too) with BERT models, and it seems to occur only with the hierarchical variant. With the non-hierarchical variant, the models load without any problem after saving, but the results are not as good as they should be.

Do you have the same issue? I am using Ubuntu 20.04 with Python 3.8.

```
[WARNING|modeling_utils.py:1501] 2022-03-28 18:01:33,192 >> Some weights of the model checkpoint at /home/X/Xs/lex-glue/seed_1 were not used when initializing BertForSequenceClassification: ['bert.encoder.encoder.layer.4.attention.self.query.weight', 'bert.seg_encoder.layers.1.self_attn.out_proj.weight', 'bert.encoder.encoder.layer.8.attention.self.query.bias', 'bert.encoder.embeddings.word_embeddings.weight', 'bert.encoder.pooler.dense.weight', 'bert.seg_pos_embeddings.weight', …] This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[WARNING|modeling_utils.py:1512] 2022-03-28 18:01:33,192 >> Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/X/X/lex-glue/seed_1 and are newly initialized: ['bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.embeddings.LayerNorm.bias', 'bert.embeddings.word_embeddings.weight', 'bert.pooler.dense.weight', …]
```
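For reference, the doubled prefix can be confirmed by inspecting the checkpoint keys directly. A minimal sketch (the checkpoint path is the one from the warning above; adjust to your own output directory):

```python
import torch

# Checkpoint path taken from the warning above; adjust to your own output_dir.
state_dict = torch.load('/home/X/Xs/lex-glue/seed_1/pytorch_model.bin',
                        map_location='cpu')

# List the parameters that a plain BertForSequenceClassification cannot match:
# the doubled 'encoder.encoder' prefix and the segment-level encoder weights.
for key in sorted(state_dict):
    if key.startswith(('bert.encoder.encoder.', 'bert.seg_encoder.', 'bert.seg_pos_embeddings')):
        print(key)
```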

iliaschalkidis commented 2 years ago

Hi @Just-Strato,

Thanks for raising this issue. It's true that the current implementation does not support post-hoc evaluation for hierarchical models, i.e., training a model and then re-running the script to evaluate only. The saved checkpoint (pytorch_model.bin) includes all parameters, but the current code relies on the HF from_pretrained() method, which only considers the parameters of the initial BERT encoder.
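To make the two loading paths concrete, here is a minimal sketch (the checkpoint directory is a placeholder): from_pretrained() silently drops every checkpoint key that does not match a plain BertForSequenceClassification, which is exactly what the warnings above report, while torch.load() returns the raw state dict with all saved parameters:

```python
import torch
from transformers import AutoModelForSequenceClassification

ckpt_dir = 'seed_1'  # placeholder: the training output_dir

# Path 1: from_pretrained() keeps only keys matching a plain
# BertForSequenceClassification; 'bert.encoder.encoder.*' and
# 'bert.seg_encoder.*' are dropped and re-initialized randomly.
model = AutoModelForSequenceClassification.from_pretrained(ckpt_dir)

# Path 2: torch.load() returns the raw state dict with every saved parameter.
state_dict = torch.load(f'{ckpt_dir}/pytorch_model.bin', map_location='cpu')
print(any(k.startswith('bert.seg_encoder.') for k in state_dict))  # True for a hierarchical checkpoint
```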

For this reason, I recently uploaded this demo script (https://github.com/coastalcph/lex-glue/blob/main/utils/load_hierbert.py), which loads the model using torch.load(). You need to load and pre-process the dataset in the same fashion as the standard training scripts (e.g., /experiments/ecthr.py), re-load the model as in that script, and call evaluate...

I think the easiest (fastest) way to do that is to modify the /experiments/ecthr.py code right after the point where the hierarchical model is initialized: https://github.com/coastalcph/lex-glue/blob/22b651301f89b963d2985bdc0b844fd7197c6457/experiments/ecthr.py#L286

Then you can re-load the model as presented in the demo code, something along these lines:

```python
# (requires `import torch` at the top of the script)
if not training_args.do_train and model_args.hierarchical:
    # Load the Hierarchical BERT model from the full saved state dict
    model_state_dict = torch.load(f'{training_args.output_dir}/pytorch_model.bin',
                                  map_location=torch.device('cpu'))
    model.load_state_dict(model_state_dict)
```

This way, the model will use all the saved parameters from pytorch_model.bin, and you'll be able to re-evaluate the model's performance.
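For completeness, a sketch of what the rest of the post-hoc evaluation could look like once the state dict is loaded, assuming the surrounding objects (training_args, eval_dataset, compute_metrics, tokenizer) are built exactly as in /experiments/ecthr.py (names follow that script, but treat them as illustrative):

```python
from transformers import Trainer

# Assumes `model` now holds the full hierarchical state dict (see above) and
# that `training_args`, `eval_dataset`, `compute_metrics`, and `tokenizer`
# are constructed as in /experiments/ecthr.py.
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

metrics = trainer.evaluate(eval_dataset=eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
```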

Would you like to give it a try? I'm up for reviewing your PR, merging it into the codebase, and giving you credit šŸ˜„

Just-Strato commented 2 years ago

Thank you for your time @iliaschalkidis. Okay, I see. I will give it a try and open a PR if I manage to make it work! šŸ˜„