josharian opened this issue 6 months ago
Thanks for the report. @mattdangerw, since the k/v shapes are different from the q shape in the 2b model, we might want to change the sharding spec for that, e.g. we could make it (None, data, None) since the first dim is always 1.
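As a rough sketch of that idea (not a tested fix), assuming two GPUs on a (1, 2) mesh with the default "batch"/"model" axis names, a user-side override could look like this:

```python
import keras
import keras_nlp

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 2), axis_names=("batch", "model"), devices=devices
)

# Start from the default layout map for the Gemma backbone.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)

# Drop the combined q/k/v rule so the overrides below don't produce
# overlapping regex matches (LayoutMap raises when a path matches twice).
for key in list(layout_map.keys()):
    if "(query|key|value)" in key:
        del layout_map[key]

# Queries keep the current spec; the k/v kernels in gemma_2b_en have a
# leading dim of 1, so leave that dim unsharded.
layout_map["decoder_block.*attention.*query.*kernel"] = ("model", "batch", None)
layout_map["decoder_block.*attention.*(key|value).*kernel"] = (None, "batch", None)

distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
with distribution.scope():
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```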
I am new to this, so definitely don't listen to me too much...but for folks like me struggling to squish this onto consumer GPUs, it'd be nice to have some model parallelism everywhere.
Thanks @josharian. Finally getting around to this -- sorry for the delay!
I think the issue we have here is that we have all of multi-head attention, multi-query attention, and now (with Gemma 2) grouped-query attention in the same Gemma architecture. To me that kinda suggests we have the wrong signature here; we need the model config to create the map. Instead of:
layout_map = keras_nlp.models.GemmaCausalLM.get_layout_map(mesh)
distribution = keras.distribution.ModelParallel(mesh, layout_map)
with distribution.scope():
    gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("./path_to_preset")
We might need:
layout_map = keras_nlp.models.GemmaCausalLM.get_layout_map("./path_to_preset", mesh)
distribution = keras.distribution.ModelParallel(mesh, layout_map)
with distribution.scope():
    gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("./path_to_preset")
Or maybe there's a better API we could have. The order of operations gets kinda awkward here: you need to create the layout map before you create the model, but you need the config of the model before you create the layout map.
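Roughly, a config-aware map would branch on the attention setup. A sketch of just the relevant part (hypothetical helper, not the current keras_nlp API; assumes `config` is a dict carrying the backbone hyperparameters):

```python
import keras


def get_gemma_layout_map(config, device_mesh, model_dim="model", data_dim="batch"):
    # Hypothetical, illustrative helper: the k/v sharding spec depends on
    # whether the model uses multi-head or multi-query attention.
    layout_map = keras.distribution.LayoutMap(device_mesh)
    # Query kernels have the head count as their leading dim, which is
    # divisible for these presets.
    layout_map["decoder_block.*attention.*query.*kernel"] = (model_dim, data_dim, None)
    if config["num_key_value_heads"] > 1:
        # Multi-head / grouped-query attention: enough k/v heads to shard.
        layout_map["decoder_block.*attention.*(key|value).*kernel"] = (model_dim, data_dim, None)
    else:
        # Multi-query attention (gemma_2b_en): leading dim is 1, don't shard it.
        layout_map["decoder_block.*attention.*(key|value).*kernel"] = (None, data_dim, None)
    # ... embedding / ffw / norm rules as in the current implementation ...
    return layout_map
```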
One alternative:
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'), devices=devices)
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(
    "./path_to_preset",
    device_mesh=device_mesh,
)
And we actually enter into a ModelParallel device scope for you inside the from_preset call.
So either option works: you set up an explicit ModelParallel scope yourself, where you control everything, or you let from_preset handle the distribution for you. If you are creating your own model from scratch (via a direct constructor call), you have to do the former, since we don't know the correct layout map to create.
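For example, the from-scratch path would stay something like this (toy hyperparameters, purely illustrative), and there the caller has to supply the layout map because only they know the architecture:

```python
import keras
import keras_nlp

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

# Hand-written layout map matching the architecture constructed below.
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["decoder_block.*attention.*query.*kernel"] = ("model", "batch", None)
# num_key_value_heads=1 below, so leave the k/v head dim unsharded.
layout_map["decoder_block.*attention.*(key|value).*kernel"] = (None, "batch", None)
# ... plus embedding / ffw rules for the rest of the model ...

distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
with distribution.scope():
    backbone = keras_nlp.models.GemmaBackbone(
        vocabulary_size=50257,
        num_layers=12,
        num_query_heads=12,
        num_key_value_heads=1,
        hidden_dim=768,
        intermediate_dim=3072,
        head_dim=64,
    )
```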
@fchollet @martin-gorner any thoughts on this and the proposal in the last comment?
Describe the bug
When attempting to shard a gemma_2b_en model across two (consumer-grade) GPUs, I get an error. The problem is the attention key/value kernels.

Comparing the gemma_2b_en decoder layer shapes against the gemma_7b_en decoder layer shapes, observe that the leading dimension of decoder_block.*attention.*(key|value).*kernel is divisible by 2/4/8/16 in gemma_7b_en but not in gemma_2b_en.
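A quick sketch for dumping the shapes in question (swap in gemma_7b_en to compare; loading it needs considerably more memory):

```python
import keras_nlp

backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")
for variable in backbone.weights:
    if "attention" in variable.path and "kernel" in variable.path:
        print(variable.path, variable.shape)
```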
Additional context
This was introduced in https://github.com/keras-team/keras-nlp/pull/1491.
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] was changed from (None, None, model_dim) to (model_dim, data_dim, None).

cc @qlzh727 @mattdangerw
There are other issues filed around LoRA training and the layout_map regular expressions. Those are unrelated; this reproduces without LoRA enabled.
Would you like to help us fix it?
Sure, although I don't know what the preferred fix is. One obvious choice would be to make this not a static method any more, so we can pick optimal layouts for each model size.