ridgerchu / matmulfreellm

Implementation for MatMul-free LM.
Apache License 2.0

Discrepancy between code and paper related to HGRNBitAttention #37

Open loki-r opened 3 months ago

loki-r commented 3 months ago

The paper and the code don't match for the last output equation.

The figure in the paper gives the last output equation as $o_t^{'} = \mathrm{RMSNorm}(h_t) * \sigma(g_t)$.

But based on the current code, it looks like the execution is

$o_t^{'} = \mathrm{RMSNorm}(g_t) * \sigma(h_t)$

instead of

$o_t^{'} = \mathrm{RMSNorm}(h_t) * \sigma(g_t)$
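
To make the difference concrete, here is a minimal plain-PyTorch sketch of the two orderings (the `rms_norm` helper and the tensor shapes are illustrative, not the repository's `g_norm` module):

    import torch

    def rms_norm(x, eps=1e-6):
        # Plain RMSNorm without a learned scale, for illustration only.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    h = torch.randn(2, 8, 16)  # stand-in for the recurrent output h_t
    g = torch.randn(2, 8, 16)  # stand-in for the gate projection g_t

    # Ordering written in the paper:
    o_paper = rms_norm(h) * torch.sigmoid(g)

    # Ordering the current code executes (arguments swapped):
    o_code = rms_norm(g) * torch.sigmoid(h)

    print(torch.allclose(o_paper, o_code))  # False: the two orderings are not equivalent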

It seems this was fixed in a recent commit to HGRN in the flash-linear-attention repository:

        last_state = (recurrent_state,)
        past_key_values.update(last_state, self.layer_idx, i.shape[2])

-       o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
+       o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'), self.g_proj(hidden_states))
        o = self.o_proj(o)

        return o, None, past_key_values

Existing code path in the current repository:

Are the reported results with the inverted equation or with the fixed equation?

ridgerchu commented 3 months ago

Hi, I think it is a bug due to the HGRN API modifications. The sigmoid should be applied to $g_t$ for better performance, but it is currently applied to $h_t$, and our pre-trained model is also still using $\sigma(h_t)$. We will fix this in the arXiv version soon, and we believe that applying it to $g_t$ would give better performance than our current version.
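
For anyone who wants to compare the two variants, here is a hypothetical sketch of an output gate with a flag selecting the ordering; the class name, flag, and plain RMSNorm are illustrative and not part of this repository's API:

    import torch
    import torch.nn as nn

    class OutputGate(nn.Module):
        """Illustrative output gate; `use_paper_order` is a hypothetical flag."""

        def __init__(self, hidden_size, use_paper_order=True, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.use_paper_order = use_paper_order
            self.eps = eps

        def _rms_norm(self, x):
            return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, h, g):
            if self.use_paper_order:
                # Paper: o_t' = RMSNorm(h_t) * sigmoid(g_t)
                return self._rms_norm(h) * torch.sigmoid(g)
            # Released checkpoints (per the reply above): o_t' = RMSNorm(g_t) * sigmoid(h_t)
            return self._rms_norm(g) * torch.sigmoid(h)

With `use_paper_order=False` the module reproduces the gating the released checkpoints were trained with; with `True` it matches the paper's equation.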