keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

GemmaBackbone.get_layout_map broken for gemma_2b_en #1613

Open josharian opened 2 months ago

josharian commented 2 months ago

Describe the bug

When attempting to shard a gemma_2b_en model across two (consumer-grade) GPUs, I get:

ValueError: One of device_put args was given the sharding of NamedSharding(mesh=Mesh('data': 1, 'model': 2), spec=PartitionSpec('model', 'data', None)), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 2048, 256))
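Roughly how I hit this — a minimal sketch assuming the JAX backend with two visible GPUs; the exact keyword names of get_layout_map are from memory and may be slightly off:

```python
import keras
import keras_nlp

# Two consumer GPUs: 1-way data parallelism x 2-way model parallelism.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 2), axis_names=("data", "model"), devices=devices
)

# Static layout map -- the same PartitionSpecs regardless of preset size.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="data",
)

keras.distribution.set_distribution(
    keras.distribution.ModelParallel(device_mesh, layout_map)
)

# Raises the ValueError above: the (1, 2048, 256) key/value kernels get a spec
# that splits dimension 0 (size 1) across the 2 devices on the "model" axis.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```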

The problem is the attention key/value kernels. gemma_2b_en decoder layer shapes:

decoder_block_0/pre_attention_norm/scale                    (2048,)         
decoder_block_0/attention/query/kernel                      (8, 2048, 256)  
decoder_block_0/attention/key/kernel                        (1, 2048, 256)  
decoder_block_0/attention/value/kernel                      (1, 2048, 256)  
decoder_block_0/attention/attention_output/kernel           (8, 256, 2048)  
decoder_block_0/pre_ffw_norm/scale                          (2048,)         
decoder_block_0/ffw_gating/kernel                           (2048, 16384)   
decoder_block_0/ffw_gating_2/kernel                         (2048, 16384)   
decoder_block_0/ffw_linear/kernel                           (16384, 2048)   

gemma_7b_en decoder layer shapes:

decoder_block_0/pre_attention_norm/scale                    (3072,)         
decoder_block_0/attention/query/kernel                      (16, 3072, 256) 
decoder_block_0/attention/key/kernel                        (16, 3072, 256) 
decoder_block_0/attention/value/kernel                      (16, 3072, 256) 
decoder_block_0/attention/attention_output/kernel           (16, 256, 3072) 
decoder_block_0/pre_ffw_norm/scale                          (3072,)         
decoder_block_0/ffw_gating/kernel                           (3072, 24576)   
decoder_block_0/ffw_gating_2/kernel                         (3072, 24576)   
decoder_block_0/ffw_linear/kernel                           (24576, 3072)   

Observe that the leading dimension of decoder_block.*attention.*(key|value).*kernel is 16 in gemma_7b_en (divisible by 2/4/8/16) but 1 in gemma_2b_en (not divisible by any model-parallel shard count greater than 1).

Additional context

This was introduced in https://github.com/keras-team/keras-nlp/pull/1491. layout_map["decoder_block.*attention.*(query|key|value).*kernel"] was changed from (None, None, model_dim) to (model_dim, data_dim, None).
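For reference, the relevant entry before and after that PR, paraphrased; layout_map, model_dim, and data_dim are as in get_layout_map:

```python
# Before keras-team/keras-nlp#1491: dim 0 (number of heads) stayed unsharded
# and only the trailing head-size dim was placed on the model axis.
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, None, model_dim
)

# After keras-team/keras-nlp#1491: dim 0 is sharded on the model axis. That is
# fine for gemma_7b_en (16 heads for q/k/v), but the gemma_2b_en key/value
# kernels are (1, 2048, 256), and a dim of size 1 cannot be split 2 ways.
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, data_dim, None
)
```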

cc @qlzh727 @mattdangerw

There are other issues filed around LoRA training and the layout_map regular expressions. This issue is unrelated; it reproduces without LoRA enabled.

Would you like to help us fix it?

Sure, although I don't know what the preferred fix is. One obvious choice would be to make get_layout_map no longer a static method, so that we can pick an optimal layout for each model size.

qlzh727 commented 2 months ago

Thanks for the report. @mattdangerw, since the k/v shapes differ from the q shape in the 2b model, we might want to change the sharding spec for them, e.g. we could make it (None, data, None), since the first dim is always 1.
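Something along these lines (sketch only; splitting key/value out of the shared regex so the query kernel keeps its current sharding):

```python
# Query kernel is (num_heads, hidden_dim, head_dim) with num_heads = 8 or 16,
# so sharding dim 0 on the model axis keeps working.
layout_map["decoder_block.*attention.*query.*kernel"] = (
    model_dim, data_dim, None
)

# Key/value kernels are (num_kv_heads, hidden_dim, head_dim), and num_kv_heads
# is 1 for gemma_2b_en, so leave dim 0 unsharded.
layout_map["decoder_block.*attention.*(key|value).*kernel"] = (
    None, data_dim, None
)
```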

josharian commented 2 months ago

> (None, data, None)

I am new to this, so definitely don't listen to me too much... but for folks like me struggling to squish this onto consumer GPUs, it would be nice to keep some model parallelism everywhere.
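For example (hypothetical and untested): instead of leaving the key/value kernels replicated on the model axis, sharding their hidden dimension (2048, which is divisible by 2) on the model axis would keep those weights split across the GPUs:

```python
# Hypothetical alternative: shard the hidden dim (2048) on the model axis so
# the key/value weights still end up split across the 2 GPUs rather than
# replicated. Whether downstream matmul layouts handle this well is untested.
layout_map["decoder_block.*attention.*(key|value).*kernel"] = (
    None, model_dim, None
)
```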