kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Is the treatment of embedding bias in to_hf_weights.py correct? #234

Closed xiaoda99 closed 2 years ago

xiaoda99 commented 2 years ago

Hello,

mesh-transformer-jax uses a linear layer with a bias for the embedding, while the HF model has no wte.embedding.bias. The linked code shows how this is handled: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/to_hf_weights.py#L386-L397

I think this treatment is incorrect. IMO, there is no way of absorbing a linear layer's bias into its weights. If we set w' = w + b (with b broadcast over the rows of w), then:

y = x w + b
y' = x w' = x (w + b) = x w + x b

The only case where y == y' is when b == 0, which is generally not true.
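(For illustration, here is a minimal JAX sketch of this argument; the shapes are hypothetical, not the repo's actual dimensions:)

```python
import jax.numpy as jnp
from jax import random

# Hypothetical small shapes, just for illustration.
key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
x = random.normal(k1, (2, 4))   # a general (non-one-hot) input
w = random.normal(k2, (4, 3))   # linear weight
b = random.normal(k3, (3,))     # linear bias

y = x @ w + b           # the original layer: x w + b
y_folded = x @ (w + b)  # bias "absorbed" into the weights (b broadcast over rows of w)

# Folding replaces the per-row "+ b" with "+ sum(x_row) * b",
# so the results differ unless every row of x sums to 1 (or b == 0).
print(jnp.allclose(y, y_folded))  # False
```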

vfbd commented 2 years ago

You are right that in the general case, y != y'. However, in this code, x is the result of a one-hot encoding at: https://github.com/kingoflolz/mesh-transformer-jax/blob/0a75ca9370576ad9d247facf6cb8e9699300e690/mesh_transformer/layers.py#L190

This means that x is always a matrix containing only 0's and 1's, with at most one 1 in each row. Therefore, every row of x w is either a row of w or a row of all 0's.

Furthermore, as long as all of the token IDs are nonnegative integers less than the vocabulary size, (the non-parallelized version of) x won't have any all-0 rows, so each row of x contains exactly one 1 and every row of x w is a row of w. Since each row of x then sums to 1, x b is just b broadcast over the rows, hence (x w) + b = x (w + b).
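A minimal sketch of this one-hot case (the vocabulary size and model width here are hypothetical, not the repo's):

```python
import jax
import jax.numpy as jnp
from jax import random

vocab, d_model = 8, 3  # hypothetical sizes
key = random.PRNGKey(0)
k1, k2 = random.split(key)
w = random.normal(k1, (vocab, d_model))  # embedding weight
b = random.normal(k2, (d_model,))        # embedding bias

tokens = jnp.array([0, 5, 2, 5])   # valid token IDs: 0 <= id < vocab
x = jax.nn.one_hot(tokens, vocab)  # each row has exactly one 1

y = x @ w + b           # linear embedding with bias
y_folded = x @ (w + b)  # bias folded into the weights

print(jnp.allclose(y, y_folded))  # True: each row selects exactly one row of (w + b)
```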

xiaoda99 commented 2 years ago

Got it, thanks!