keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Add `get_layout_map()` for all backbones #1689

Open mattdangerw opened 2 weeks ago

mattdangerw commented 2 weeks ago

We want model parallelism to be easy to use across the library.

We should add a `get_layout_map()` implementation to every backbone. This should be mostly copy-paste from the Gemma version, since all transformers have pretty much the same weight structure: https://github.com/keras-team/keras-nlp/blob/a00efc29c24eb15c735c220090775a72dcdd21c8/keras_nlp/src/models/gemma/gemma_backbone.py#L227-L315

mattdangerw commented 2 weeks ago

We should also keep the docstring for the method on the `Backbone` base class, and factor out all the error checking somehow. That way the per-model code could be really minimal.
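
A rough sketch of what that split could look like, in plain Python so it runs without Keras. Everything here is an assumption for illustration: the `_layout_rules` hook, the `ToyBackbone` class, the argument names, and the variable-path patterns are all hypothetical, and the mesh is reduced to a tuple of axis names rather than a real device mesh object.

```python
class Backbone:
    """Stand-in base class for this sketch (not the real keras-nlp class)."""

    @classmethod
    def get_layout_map(cls, mesh_axis_names, model_dim="model", data_dim="batch"):
        """Return a {variable-path regex: layout tuple} dict for this model.

        The docstring and the validation live here once; subclasses only
        supply their sharding rules through `_layout_rules`.
        """
        axis_names = tuple(mesh_axis_names)
        # Shared error checking, factored out of the per-model code.
        for dim in (model_dim, data_dim):
            if dim not in axis_names:
                raise ValueError(
                    f"Axis {dim!r} not found in mesh axes {axis_names}."
                )
        if model_dim == data_dim:
            raise ValueError("Model and data axes must be different.")
        return cls._layout_rules(model_dim, data_dim)

    @classmethod
    def _layout_rules(cls, model_dim, data_dim):
        # Hypothetical hook: each backbone overrides this.
        raise NotImplementedError(f"{cls.__name__} defines no layout rules.")


class ToyBackbone(Backbone):
    """Hypothetical transformer backbone; the patterns are illustrative only."""

    @classmethod
    def _layout_rules(cls, model_dim, data_dim):
        # The per-model code shrinks to a table of regex -> layout.
        return {
            "token_embedding/embeddings": (model_dim, data_dim),
            "decoder_block.*ffw_gating.*kernel": (data_dim, model_dim),
            "decoder_block.*ffw_linear.*kernel": (model_dim, data_dim),
        }


if __name__ == "__main__":
    for pattern, layout in ToyBackbone.get_layout_map(("batch", "model")).items():
        print(pattern, layout)
```

In the real method the validated rules would presumably be written into a layout map object built from the device mesh instead of being returned as a dict; the point of the sketch is only that each backbone contributes a small table while the base class owns the docstring and the checks.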