keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Add task base classes; support out of tree library extensions #1517

Closed: mattdangerw closed this 5 months ago

mattdangerw commented 5 months ago

This PR grew as I was writing it, and now tries out a number of new things:

  1. Expose base classes. This sets us on a path toward better documentation and a more "introspectable" library, and allows subclassing.
  2. Enable from_preset() on base classes for any subclass. This gives us functionality similar to "auto classes" in Hugging Face, without the extra overhead of needing a new symbol.
  3. Add the ability to register new tasks/backbones/tokenizers from out-of-tree code with keras.saving.register_keras_serializable().
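The base-class from_preset() dispatch described in items 2 and 3 can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the KerasNLP implementation: the `_REGISTRY` dict, the `register` decorator (a stand-in for keras.saving.register_keras_serializable()), and the `PRESETS` table mapping a preset handle to a config with a `class_name` field are all hypothetical.

```python
# Hypothetical sketch: base-class from_preset() dispatch through a class registry.
# _REGISTRY, register(), and PRESETS are illustrative stand-ins, not KerasNLP APIs.

_REGISTRY = {}  # maps a registered class name to the class object


def register(cls):
    """Stand-in for keras.saving.register_keras_serializable(): record the class by name."""
    _REGISTRY[cls.__name__] = cls
    return cls


# Hypothetical preset table: each preset handle records which class saved it.
PRESETS = {
    "bert_base_en": {"class_name": "BertTokenizer"},
    "mistral_7b_en": {"class_name": "MistralTokenizer"},
}


@register
class Tokenizer:
    @classmethod
    def from_preset(cls, handle):
        # Look up the class the preset was saved with, then resolve it by name.
        class_name = PRESETS[handle]["class_name"]
        preset_cls = _REGISTRY[class_name]
        # Only load if the saved class is the calling class or one of its subclasses.
        if not issubclass(preset_cls, cls):
            raise ValueError(
                f"Preset has type {preset_cls.__name__} which is not a subclass "
                f"of calling class {cls.__name__}."
            )
        return preset_cls()


@register
class BertTokenizer(Tokenizer):
    pass


@register
class MistralTokenizer(Tokenizer):
    pass


tokenizer = Tokenizer.from_preset("mistral_7b_en")  # resolves to MistralTokenizer
```

Calling `Tokenizer.from_preset("mistral_7b_en")` on the base class returns a `MistralTokenizer`, which is the "auto class" behavior without a separate symbol. Calling `BertTokenizer.from_preset("mistral_7b_en")` raises a `ValueError` because the preset's class is not a subclass of the caller.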

Colab playground: https://colab.research.google.com/gist/mattdangerw/da885f050fa8baef9b4f9a4ec68d6567/kerasnlp-base-classes.ipynb

fchollet commented 5 months ago

> Expose base classes. This sets us on a path toward better documentation and a more "introspectable" library, and allows subclassing.

Sure, that's definitely a great addition.

> Enable from_preset() on base classes for any subclass. This gives us functionality similar to "auto classes" in Hugging Face, without the extra overhead of needing a new symbol.

Makes perfect sense as well.

> An ability to register new tasks/backbones/tokenizers from out-of-tree code with keras_nlp.utils.register_preset_class().

Why does this need to be special-cased? Wouldn't register_keras_serializable just work?

> ValueError: Preset has type MistralTokenizer which is not a a subclass or equal to calling class BertTokenizer. Call from_preset directly on MistralTokenizer instead.

To help with debugging, include any and all relevant information, e.g. "Preset 'kaggle://keras/mistral/keras/mistral_7b_en' has type MistralTokenizer ...". That makes error messages much easier to parse; otherwise it takes a few seconds of pondering to figure out what "preset" refers to.

Likewise: "Call MistralTokenizer.from_preset('kaggle://keras/mistral/keras/mistral_7b_en') directly instead." Then you just need to copy and paste that code.
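A helper that builds an error message in the form suggested above might look like the sketch below. The function name `preset_class_mismatch_message` is hypothetical; it only formats the string and does not reflect how KerasNLP actually raises this error.

```python
def preset_class_mismatch_message(handle, preset_cls_name, calling_cls_name):
    """Hypothetical helper: name the preset and give a copy-pasteable fix."""
    return (
        f"Preset '{handle}' has type {preset_cls_name}, which is not a subclass "
        f"of (or equal to) the calling class {calling_cls_name}. "
        f"Call {preset_cls_name}.from_preset('{handle}') directly instead."
    )


msg = preset_class_mismatch_message(
    "kaggle://keras/mistral/keras/mistral_7b_en", "MistralTokenizer", "BertTokenizer"
)
print(msg)
```

Putting the full handle in both the diagnosis and the suggested call means the reader never has to work out which preset is meant, and the suggested line can be pasted as-is.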

mattdangerw commented 5 months ago

> Why does this need to be special-cased? Wouldn't register_keras_serializable just work?

Good question. In short, I think it could, via keras.saving.get_custom_objects(). But that would mean occasionally walking the full list of custom objects to discover what is registered. For example, if we are looking for a classifier that can load a FooBarBackbone, we would need to walk all custom objects until we find a classifier subclass where classifier.backbone_cls == FooBarBackbone.
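The lookup described above can be sketched as a linear scan over a name-to-class mapping like the one keras.saving.get_custom_objects() returns. To keep the example self-contained, the `custom_objects` dict below is a hypothetical stand-in for that registry (its key strings are illustrative), and `Backbone`, `Classifier`, `FooBarClassifier`, and `find_task_for_backbone` are hypothetical names; only `FooBarBackbone` and the `backbone_cls` attribute come from the discussion.

```python
# Illustrative classes; only FooBarBackbone and backbone_cls mirror the discussion.
class Backbone:
    pass


class FooBarBackbone(Backbone):
    pass


class Classifier:
    backbone_cls = None


class FooBarClassifier(Classifier):
    backbone_cls = FooBarBackbone


# Hypothetical stand-in for the dict returned by keras.saving.get_custom_objects().
custom_objects = {
    "my_package>FooBarBackbone": FooBarBackbone,
    "my_package>FooBarClassifier": FooBarClassifier,
}


def find_task_for_backbone(task_base_cls, backbone_cls, objects):
    """Walk every registered object until a task subclass matches backbone_cls."""
    for obj in objects.values():
        if (
            isinstance(obj, type)
            and issubclass(obj, task_base_cls)
            and obj.backbone_cls is backbone_cls
        ):
            return obj
    return None  # nothing registered can load this backbone


print(find_task_for_backbone(Classifier, FooBarBackbone, custom_objects).__name__)
```

The scan is linear in the number of registered objects, so its cost grows with everything a user has registered, not just tasks.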

But in practice that's probably negligible compared to the real slowdowns here (allocating memory for an LLM, downloading weights, reading weights from a file, etc.). I will play around with this.

mattdangerw commented 5 months ago

Reworked our subclassing experience to be as smooth as possible.

mattdangerw commented 5 months ago

OK, this turned into quite the project, but I think it's ready for review.

Most of the diff comes from cleaning up our subclassing experience by removing cruft from the subclasses themselves.