Closed: EricLBuehler closed this issue 4 months ago
Need to modify Gemma model implementation with:
```python
if self.config.attn_logit_softcapping is not None:
    attn_weights = attn_weights / self.config.attn_logit_softcapping
    attn_weights = torch.tanh(attn_weights)
    attn_weights = attn_weights * self.config.attn_logit_softcapping
```
```python
if self.config.final_logit_softcapping is not None:
    logits = logits / self.config.final_logit_softcapping
    logits = torch.tanh(logits)
    logits = logits * self.config.final_logit_softcapping
```
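For intuition, the soft-capping above just squashes values smoothly into the range `(-cap, cap)`. A minimal standalone sketch (the `soft_cap` helper and the example values are illustrative, not taken from the model code):

```python
import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Divide by the cap, squash into (-1, 1) with tanh, then scale back up,
    # so outputs are smoothly bounded by +/- cap.
    return torch.tanh(x / cap) * cap

scores = torch.tensor([-120.0, -5.0, 0.0, 5.0, 120.0])
print(soft_cap(scores, 50.0))  # approximately [-49.2, -5.0, 0.0, 5.0, 49.2]
```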
- Attention scaling by `query_pre_attn_scalar` instead of `1/sqrt(head_dim)` (see the sketch below)
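A rough sketch of what that scaling change amounts to, under the assumption that the scores are multiplied by `query_pre_attn_scalar ** -0.5` from the config rather than `head_dim ** -0.5` (the function name and the example value 256 are assumptions for illustration, not quoted from the issue):

```python
import torch

# Hypothetical illustration of the changed attention scaling.
# query_pre_attn_scalar is read from the model config; 256 here is only an example value.
def attn_scores(q: torch.Tensor, k: torch.Tensor, query_pre_attn_scalar: float) -> torch.Tensor:
    scaling = query_pre_attn_scalar ** -0.5        # replaces head_dim ** -0.5
    return torch.matmul(q, k.transpose(-2, -1)) * scaling

q = torch.randn(1, 8, 16, 256)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 256)
scores = attn_scores(q, k, 256.0)
```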
Implemented in #490.
Changelist over original Gemma and status:
- `query_pre_attn_scalar` instead of `1/sqrt(head_dim)`