keras-team / keras-hub

Pretrained model hub for Keras 3

Add auto variable sharding for all backbones/tasks #1689

Open mattdangerw opened 2 months ago

mattdangerw commented 2 months ago

We want model parallelism to be easy to use across the library. At a high level, a user should describe their hardware and (optionally) the desired model-parallel vs. data-parallel split for the device grid.

Currently, we have an auto layout helper for Gemma here, but it is not a scalable design. The correct layout map will depend on the config of the model. E.g., you need to shard a Gemma model with multi-head attention differently than one with multi-query attention.
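
To make the config dependence concrete, here is a hypothetical sketch of a config-aware layout map for Gemma. The `get_layout_map` name, the `num_key_value_heads` config key, the variable path regexes, and the partition tuples are all illustrative assumptions, not the exact recipe in the current helper:

```python
import keras

# Hypothetical sketch: a layout map that adapts to the model config. The
# variable path regexes and partition tuples are illustrative only.
def get_layout_map(device_mesh, config, model_dim="model"):
    layout_map = keras.distribution.LayoutMap(device_mesh)
    layout_map["token_embedding/embeddings"] = (model_dim, None)
    if config["num_key_value_heads"] == 1:
        # Multi-query attention: the key/value projections have a single
        # head, so only the query kernel can be split over the head axis.
        layout_map["decoder_block.*attention.*query.*kernel"] = (None, model_dim, None)
        layout_map["decoder_block.*attention.*(key|value).*kernel"] = (None, None, None)
    else:
        # Multi-head attention: shard query/key/value over the head axis.
        layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (None, model_dim, None)
    layout_map["decoder_block.*attention_output.*kernel"] = (None, None, model_dim)
    layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
    layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
    return layout_map
```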

I think there are two main directions we could go with the API:

1. Write our own manual sharding recipe for each model, given that model's config. Do this for all models (most will share the same recipe, especially our transformer models).
2. Use some form of autosharding functionality in JAX, or add an autosharding API to Keras. In this case, we would not need to write the sharding recipes ourselves per model.

One potential high-level API would be to take a device mesh directly when constructing the model. For both 1) and 2), we could support an API something like this:

```python
import keras
import keras_nlp

device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=("batch", "model"), devices=keras.distribution.list_devices()
)
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(
    "gemma_2b_en",
    device_mesh=device_mesh,
)
```

For 1), we would need to enter a LayoutMap scope after loading the config for a model but before loading the weights. For 2), it would depend on the details of the autosharding API we use.
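
For reference, this is roughly what a user has to write today with the existing Gemma helper, and what the proposed `device_mesh` argument would fold into `from_preset()`. A minimal sketch, assuming the current `keras.distribution` API and `GemmaBackbone.get_layout_map`; the library code would likely enter a scope internally rather than set the global distribution:

```python
import keras
import keras_nlp

device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=("batch", "model"), devices=keras.distribution.list_devices()
)
# `get_layout_map` stands in for the per-model recipe from option 1.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(
    device_mesh=device_mesh, layout_map=layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
# Weights are loaded (and sharded) while the distribution is active.
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```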

mattdangerw commented 2 months ago

We should also keep the docstring for the method on the Backbone base class, and factor out all the error checking somehow, so that the per-model code here can be really minimal.
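
One possible shape for that split, as a hedged sketch (the method name `get_layout_map`, the `_add_layouts` hook, and the specific checks are assumptions, not an agreed design):

```python
import keras

class Backbone(keras.Model):
    @classmethod
    def get_layout_map(cls, device_mesh, model_parallel_dim_name="model"):
        """Return a `keras.distribution.LayoutMap` for sharding this model.

        (The full, shared docstring would live here on the base class.)
        """
        # Shared error checking, factored out of the per-model helpers.
        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
            raise ValueError(
                "`device_mesh` must be a `keras.distribution.DeviceMesh`. "
                f"Received: {device_mesh}"
            )
        if model_parallel_dim_name not in device_mesh.axis_names:
            raise ValueError(
                f"'{model_parallel_dim_name}' is not an axis name of "
                f"`device_mesh`. Received: {device_mesh.axis_names}"
            )
        layout_map = keras.distribution.LayoutMap(device_mesh)
        cls._add_layouts(layout_map, model_parallel_dim_name)
        return layout_map

    @classmethod
    def _add_layouts(cls, layout_map, model_dim):
        # Per-model sharding recipe, overridden by e.g. `GemmaBackbone`.
        raise NotImplementedError
```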