elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)

Add Gemma attention head size #364

Closed · cmeon closed this 3 months ago

cmeon commented 3 months ago

This fix adds `attention_head_size`, which for Gemma is 256 rather than the usual `hidden_size / num_attention_heads`.

Without this, the runtime ignores the parameters with mismatched shapes during model building, which leads to unintelligible output:

00:15:21.191 [debug] the following parameters were ignored, because of non-matching shape:

  * decoder.blocks.24.self_attention.value.kernel (expected {3072, 3072}, got: {3072, 4096})
  * decoder.blocks.18.self_attention.value.kernel (expected {3072, 3072}, got: {3072, 4096})
  * decoder.blocks.18.self_attention.output.kernel (expected {3072, 3072}, got: {4096, 3072})
  ...
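For context, here is a minimal sketch of the shape arithmetic behind those messages (the concrete numbers assume Gemma 7B's 3072 hidden size and 16 attention heads; the variable names are illustrative, not Bumblebee internals):

```elixir
# Illustrative only: shows why the projection shapes differ when the head
# size is decoupled from the hidden size.

hidden_size = 3072
num_attention_heads = 16

# Default assumption: head size derived from the hidden size.
derived_head_size = div(hidden_size, num_attention_heads)
# => 192, so the value projection is built as {3072, 16 * 192} = {3072, 3072}

# Gemma instead fixes the head size at 256.
attention_head_size = 256
inner_size = num_attention_heads * attention_head_size
# => 4096, so the checkpoint's value kernel is {3072, 4096}, hence the
#    "expected {3072, 3072}, got: {3072, 4096}" messages above.

IO.inspect({hidden_size, inner_size}, label: "value kernel shape")
```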

See: