castorini / DeeBERT

DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference
Apache License 2.0

Entropy calculation can be modified #17

Closed: sbwww closed this issue 2 years ago

sbwww commented 2 years ago

When calculating entropy, dim=1 should be replaced with dim=-1, since num_labels is always the last dimension of the logits but not necessarily the second (e.g., in token classification, the logits have shape (batch_size, seq_len, num_labels), so num_labels is the third dimension).
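A minimal sketch of why dim=-1 is the safer choice (the shapes below are illustrative, not taken from DeeBERT itself):

import torch

# Sequence classification: logits have shape (batch_size, num_labels)
seq_logits = torch.randn(8, 3)
# Token classification: logits have shape (batch_size, seq_len, num_labels)
tok_logits = torch.randn(8, 128, 3)

# dim=-1 always refers to num_labels; dim=1 only happens to in the 2-D case.
print(seq_logits.shape[-1], seq_logits.shape[1])  # 3 3   -> dim=1 works by coincidence
print(tok_logits.shape[-1], tok_logits.shape[1])  # 3 128 -> dim=1 would reduce over seq_len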

And what about using softmax directly?

import torch

def entropy(x):  # my attempt
    p = torch.softmax(x, dim=-1)                 # softmax-normalized probability distribution
    return -torch.sum(p * torch.log(p), dim=-1)  # entropy over probs: -\sum(p_i \ln(p_i))

def entropy(x):  # original code
    exp_x = torch.exp(x)
    A = torch.sum(exp_x, dim=1)      # A = \sum(exp(x_i))
    B = torch.sum(x * exp_x, dim=1)  # B = \sum(x_i * exp(x_i))
    return torch.log(A) - B / A      # equals -\sum(p_i \ln(p_i)) with p_i = exp(x_i) / A
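For reference, the two are algebraically identical: with A = \sum_j exp(x_j), B = \sum_j x_j exp(x_j), and softmax probabilities p_i = exp(x_i) / A,

-\sum_i p_i \ln(p_i) = -\sum_i (exp(x_i) / A) * (x_i - \ln(A)) = \ln(A) - B / A

which is exactly what the original code returns.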

Empirically the outputs match as well, and calling softmax directly is generally more efficient than the manual calculation.
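A quick sanity check (a minimal sketch; the names entropy_softmax / entropy_manual and the test shape are mine, not from the repo):

import torch

def entropy_softmax(x):  # the proposed variant
    p = torch.softmax(x, dim=-1)
    return -torch.sum(p * torch.log(p), dim=-1)

def entropy_manual(x):   # the original variant, with dim=-1 for a fair comparison
    exp_x = torch.exp(x)
    A = torch.sum(exp_x, dim=-1)
    B = torch.sum(x * exp_x, dim=-1)
    return torch.log(A) - B / A

logits = torch.randn(32, 3)  # arbitrary batch of 3-class logits
print(torch.allclose(entropy_softmax(logits), entropy_manual(logits), atol=1e-6))  # True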

Since DeeBERT is proposed for efficiency, it makes sense to choose the more efficient implementation.

Mind if I open a PR for this? @ji-xin

ji-xin commented 2 years ago

Thanks for the suggestion! Sure, please feel free to open a PR for it.