Open mattdangerw opened 2 weeks ago
We want model parallelism to be easy to use across the library.
We should add a `get_layout_map()` implementation to all backbones. This should be mostly copy-paste from the Gemma version, since almost all transformers share the same weight structure. https://github.com/keras-team/keras-nlp/blob/a00efc29c24eb15c735c220090775a72dcdd21c8/keras_nlp/src/models/gemma/gemma_backbone.py#L227-L315
We should also keep the docstring for the method on the `Backbone` base class, and factor out all the error checking somehow, so that the per-model code here could be really minimal.
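One way the factoring could look, as a framework-free sketch: the `Backbone` base class owns the docstring and the shared error checking, and each model only supplies a small dict of weight-path patterns. The class names, the `LAYOUT_RULES` attribute, and the specific paths and layouts below are all illustrative assumptions, not the actual keras-nlp API beyond `get_layout_map()` itself.

```python
from dataclasses import dataclass


@dataclass
class DeviceMesh:
    """Minimal stand-in for keras.distribution.DeviceMesh (axis names only)."""

    axis_names: tuple


class Backbone:
    # Hypothetical hook: subclasses map weight-path regexes to layout
    # templates, using the "MODEL" placeholder for the model-parallel axis.
    LAYOUT_RULES = {}

    @classmethod
    def get_layout_map(cls, device_mesh, model_dim="model", data_dim="batch"):
        """Return a map of weight-path patterns to sharding layouts.

        The docstring and error checking live here on the base class, so
        per-model code stays minimal.
        """
        # Shared validation, factored out of every backbone.
        if model_dim not in device_mesh.axis_names:
            raise ValueError(
                f"{model_dim!r} is not an axis name of the mesh "
                f"{device_mesh.axis_names}."
            )
        if data_dim not in device_mesh.axis_names:
            raise ValueError(
                f"{data_dim!r} is not an axis name of the mesh "
                f"{device_mesh.axis_names}."
            )
        # Substitute the concrete axis name into each per-model rule.
        return {
            pattern: tuple(model_dim if d == "MODEL" else d for d in layout)
            for pattern, layout in cls.LAYOUT_RULES.items()
        }


class TransformerBackbone(Backbone):
    # Per-model code reduces to a dict like this (paths are made up here,
    # not copied from the real Gemma layout map).
    LAYOUT_RULES = {
        "token_embedding/embeddings": ("MODEL", None),
        "decoder_block.*attention.*(query|key|value)/kernel": ("MODEL", None, None),
    }


mesh = DeviceMesh(axis_names=("batch", "model"))
layout_map = TransformerBackbone.get_layout_map(mesh)
```

In the real library the returned object would be a `keras.distribution.LayoutMap` rather than a plain dict, but the division of labor (validation on the base class, a small declarative table per model) is the part this sketch is meant to show.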