Subuday opened 3 months ago
Hey, good question :)
The `8` is `sqrt(64)`, which is the default `head_width` (so by default, the whole expression is going to be `query_mult * 1`). This is similar to the `3` in the `base_width` calculation.
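Here is a quick sketch of the arithmetic. It assumes the expression is evaluated exactly as written in `qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)`. The free functions `qk_scale` and `qk_scale_explicit` and the constant `DEFAULT_HEAD_WIDTH` are hypothetical names used only for illustration:

```python
import math

# Hypothetical constant: the default head width (8 == sqrt(64)).
DEFAULT_HEAD_WIDTH = 64

def qk_scale(query_mult: float, head_width: int) -> float:
    # Same expression as self.tunables.query_mult * 8 / math.sqrt(head_width),
    # pulled out into a free function for illustration.
    return query_mult * 8 / math.sqrt(head_width)

def qk_scale_explicit(query_mult: float, head_width: int) -> float:
    # Equivalent form with the 8 spelled out as sqrt(DEFAULT_HEAD_WIDTH).
    return query_mult * math.sqrt(DEFAULT_HEAD_WIDTH) / math.sqrt(head_width)

for hw in (16, 64, 256):
    print(hw, qk_scale(1.0, hw), qk_scale_explicit(1.0, hw))
```

At the default width the extra factor is exactly 1, so `query_mult` alone sets the scale. Quadrupling `head_width` to 256 halves it, and shrinking it to 16 doubles it.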
The "Maximal Update Parametrization" (μP) from the Tensor Programs V paper requires us to assume some "base" widths, which we then use to determine whether we need to adjust the learning rate, initialization stddev, etc. The choice is completely arbitrary, since the only thing that matters is the ratio between this base width and the lr/init_stddev hyperparameters (which we tune at the end anyway).
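The following minimal sketch shows why the base width is arbitrary. It assumes the μP rule that the Adam learning rate for hidden weights shrinks like 1/width, expressed relative to a chosen base width. The function `mup_hidden_lr` and all numeric values are hypothetical and are not taken from the code under discussion:

```python
import math

def mup_hidden_lr(tuned_lr: float, width: int, base_width: int) -> float:
    # Assumed μP-style rule for hidden weights under Adam: the effective LR
    # scales like base_width / width around a reference ("base") width.
    return tuned_lr * base_width / width

# Two different (arbitrary) base-width choices, each with its own tuned LR.
setup_a = dict(tuned_lr=1e-3, base_width=64)
setup_b = dict(tuned_lr=5e-4, base_width=128)  # base doubled, tuned LR halved

for width in (64, 256, 1024):
    lr_a = mup_hidden_lr(width=width, **setup_a)
    lr_b = mup_hidden_lr(width=width, **setup_b)
    # Both setups yield the same effective LR at every width.
    print(width, lr_a, lr_b, math.isclose(lr_a, lr_b))
```

Changing the base width only rescales the hyperparameter you end up tuning, so the effective learning rate at every actual width stays the same. The same reasoning applies to any init-stddev rule of this multiplicative form.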
Could you please provide some insight into the `8` in this line of code:
```python
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
```