castorini / DeeBERT

DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference
Apache License 2.0

Confusing code in class BertHighway in modeling_highway_bert.py #16

Closed: sbwww closed this issue 2 years ago

sbwww commented 2 years ago

Greetings! DeeBERT is a significant and easy-to-understand contribution to BERT inference acceleration.

However, the forward function of class BertHighway in transformers/modeling_highway_bert.py is a bit confusing. Your original code is as follows:

```python
def forward(self, encoder_outputs):
    # Pooler
    pooler_input = encoder_outputs[0]
    pooler_output = self.pooler(pooler_input)
    # "return" pooler_output

    # BertModel
    bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]
    # "return" bmodel_output

    # Dropout and classification
    pooled_output = bmodel_output[1]

    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

    return logits, pooled_output
```

I am not sure why bmodel_output is introduced, since it is never actually used.

It seems that pooler_input is just sequence_output, and pooled_output is the same tensor as pooler_output (bmodel_output[1] is pooler_output itself).
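To check this reading, here is a minimal sketch that replays the quoted forward logic with plain Python callables instead of the repo's nn.Module layers. The function names `forward_original` and `forward_simplified`, and the toy pooler/dropout/classifier, are hypothetical stand-ins for illustration only (the toy pooler just takes the first token's vector, unlike the real BertPooler, and dropout is the identity, as in eval mode). It shows that index 1 of the tuple is pooler_output itself, so dropping bmodel_output gives the same result.

```python
from typing import Any, Callable, Sequence, Tuple


def forward_original(encoder_outputs: Sequence[Any], pooler: Callable,
                     dropout: Callable, classifier: Callable) -> Tuple[Any, Any]:
    """Replays the quoted BertHighway.forward with layers passed in as callables."""
    pooler_input = encoder_outputs[0]  # i.e. sequence_output
    pooler_output = pooler(pooler_input)
    bmodel_output = (pooler_input, pooler_output) + tuple(encoder_outputs[1:])
    pooled_output = bmodel_output[1]  # index 1 is pooler_output itself
    assert pooled_output is pooler_output
    pooled_output = dropout(pooled_output)
    return classifier(pooled_output), pooled_output


def forward_simplified(encoder_outputs: Sequence[Any], pooler: Callable,
                       dropout: Callable, classifier: Callable) -> Tuple[Any, Any]:
    """Same computation without building the unused bmodel_output tuple."""
    pooled_output = dropout(pooler(encoder_outputs[0]))
    return classifier(pooled_output), pooled_output


# Hypothetical toy stand-ins, not the real BERT layers.
def toy_pooler(seq):
    return seq[0]  # first token's vector


def toy_dropout(x):
    return x  # identity, as in eval mode


def toy_classifier(vec):
    return [2 * v for v in vec]  # pretend "logits"


toy_encoder_outputs = ([[1.0, 2.0], [3.0, 4.0]], "hidden_states")

print(forward_original(toy_encoder_outputs, toy_pooler, toy_dropout, toy_classifier))
print(forward_simplified(toy_encoder_outputs, toy_pooler, toy_dropout, toy_classifier))
# both print ([2.0, 4.0], [1.0, 2.0])
```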

Is there a subtlety I am missing? If not, perhaps the comments starting with "return" could be expanded to explain this.

ji-xin commented 2 years ago

You are right: bmodel_output is not actually used, and all we need is pooler_output. I wrote it this way mainly so that the intermediate values would line up with the outputs of BertPooler, BertModel, etc. at the time of implementation.
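One way to picture that alignment: in the older, tuple-returning versions of transformers, BertModel returned (sequence_output, pooled_output) followed by any optional hidden states and attentions, and a sequence-classification head read index 1 of that tuple. The sketch below splits the highway computation into those two stages. The helpers `bert_model_outputs`, `classification_head`, and `highway_forward` are hypothetical names, and the toy layers are illustrative stand-ins, not the repo's modules. The point is only that a head written against the BertModel layout can consume the mirrored tuple unchanged.

```python
from typing import Any, Callable, Sequence, Tuple


def bert_model_outputs(encoder_outputs: Sequence[Any], pooler: Callable) -> Tuple[Any, ...]:
    """Hypothetical helper: build a BertModel-style tuple
    (sequence_output, pooled_output, *optional hidden_states/attentions)."""
    sequence_output = encoder_outputs[0]
    pooled_output = pooler(sequence_output)
    return (sequence_output, pooled_output) + tuple(encoder_outputs[1:])


def classification_head(model_outputs: Sequence[Any], dropout: Callable,
                        classifier: Callable) -> Tuple[Any, Any]:
    """Hypothetical helper: a head that reads index 1, like a sequence-classification head."""
    pooled_output = dropout(model_outputs[1])
    return classifier(pooled_output), pooled_output


def highway_forward(encoder_outputs: Sequence[Any], pooler: Callable,
                    dropout: Callable, classifier: Callable) -> Tuple[Any, Any]:
    # Same result as BertHighway.forward, written as the two stages it mirrors.
    return classification_head(bert_model_outputs(encoder_outputs, pooler), dropout, classifier)


# Hypothetical toy stand-ins, not the real BERT layers.
def toy_pooler(seq):
    return seq[0]  # first token's vector


def toy_dropout(x):
    return x  # identity, as in eval mode


def toy_classifier(vec):
    return [2 * v for v in vec]  # pretend "logits"


toy_encoder_outputs = ([[1.0, 2.0], [3.0, 4.0]], "hidden_states")

print(bert_model_outputs(toy_encoder_outputs, toy_pooler))
# prints ([[1.0, 2.0], [3.0, 4.0]], [1.0, 2.0], 'hidden_states')
print(highway_forward(toy_encoder_outputs, toy_pooler, toy_dropout, toy_classifier))
# prints ([2.0, 4.0], [1.0, 2.0])
```

The extra stage costs nothing numerically; it only keeps the intermediate tuple shaped like BertModel's output, which is the design choice described above.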