databricks / megablocks


Why does the second matrix of the MLP layer have the same shape as the first one? #81

Open gouchangjiang opened 6 months ago

gouchangjiang commented 6 months ago

This is more a question than an issue. The tensor w2 of the SparseMLP class has the same shape as w1. Is that because of the DSD operation, i.e., does it require the two matrices to have the same shape? Usually the shape of w2 is the transpose of w1's, as in the MLP of GPT-2. Thanks in advance.
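For concreteness, here is a minimal dense sketch of the conventional layout I mean (shapes and the names `hidden_size` / `ffn_hidden_size` are illustrative, not taken from the megablocks code):

```python
import torch

hidden_size, ffn_hidden_size = 768, 3072

# GPT-2-style MLP: w2 is shaped as the transpose of w1.
w1 = torch.randn(hidden_size, ffn_hidden_size)  # [768, 3072]
w2 = torch.randn(ffn_hidden_size, hidden_size)  # [3072, 768]

x = torch.randn(4, hidden_size)                 # a batch of 4 tokens
h = torch.nn.functional.gelu(x @ w1)            # [4, 3072]
y = h @ w2                                      # [4, 768]
```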

tgale96 commented 6 months ago

Hi! Excellent question!

We get significantly better performance with SparseMLP by using weight data layouts where the inner-most dimension is hidden_size, likely because of cache effects, although I haven't looked into it more deeply than observing the difference :)
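For illustration, a minimal dense sketch of that layout, assuming both weights are stored as `[ffn_hidden_size, hidden_size]` so that hidden_size is the inner-most (contiguous) dimension. The real SparseMLP runs block-sparse kernels, so this is only a shape-level stand-in, not the actual implementation:

```python
import torch

hidden_size, ffn_hidden_size = 768, 3072

# Same-shape layout: both weights keep hidden_size as the
# inner-most dimension, matching the description above.
w1 = torch.randn(ffn_hidden_size, hidden_size)  # [3072, 768]
w2 = torch.randn(ffn_hidden_size, hidden_size)  # [3072, 768]

x = torch.randn(4, hidden_size)                 # a batch of 4 tokens
h = torch.nn.functional.gelu(x @ w1.t())        # [4, 3072]
y = h @ w2                                      # [4, 768], same output shape as before
```

The math is unchanged relative to the GPT-2-style layout; the only difference is which operand gets transposed at use time.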