
EasyDeL 🔮

Key Features | Latest Updates | Vision | Quick Start | Reference docs | License

EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on Jax/Flax. It provides convenient and effective solutions for training and serving Flax/Jax models on TPU/GPU at scale.

Key Features

Fully Customizable and Hackable 🛠️

EasyDeL stands out by providing unparalleled flexibility and transparency.

With EasyDeL, you're not constrained by rigid frameworks. Instead, you have a flexible, powerful toolkit that adapts to your needs, no matter how unique or specialized they may be. Whether you're conducting cutting-edge research or building production-ready ML systems, EasyDeL provides the freedom to innovate without limitations.

Advanced Customization and Optimization 🔧

EasyDeL gives you deep control over how your models are customized and optimized.

This level of customization allows you to squeeze every ounce of performance from your hardware while tailoring the model behavior to your exact requirements.

Future Updates and Vision 🚀

EasyDeL is constantly evolving to meet the needs of the machine learning community, and further improvements and features are planned for upcoming releases.

Why Choose EasyDeL?

  1. Flexibility: EasyDeL offers a modular design that allows researchers and developers to easily mix and match components, experiment with different architectures (including Transformers, Mamba, RWKV, and more), and adapt models to specific use cases.

  2. Performance: Leveraging the power of JAX and Flax, EasyDeL provides high-performance implementations of state-of-the-art models and training techniques, optimized for both TPUs and GPUs.

  3. Scalability: From small experiments to large-scale model training, EasyDeL provides tools and optimizations to efficiently scale your models and workflows.

  4. Ease of Use: Despite its powerful features, EasyDeL maintains an intuitive API, making it accessible for both beginners and experienced practitioners.

  5. Cutting-Edge Research: EasyDeL quickly implements the latest advancements in model architectures, training techniques, and optimization methods.

Quick Start

Installation

pip install easydel
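
Since EasyDeL targets TPU/GPU at scale, it is worth confirming that JAX can actually see your accelerator before training; the check below uses only the standard JAX API:

import jax

# Should list TPU or GPU devices; a CPU-only list means JAX was installed
# without accelerator support.
print(jax.devices())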

Testing Attention Mechanisms

import easydel as ed

# Run the built-in tests of the available attention implementations.
ed.FlexibleAttentionModule.test_attentions()

Documentation 💫

Comprehensive documentation and examples are available at the EasyDeL documentation site: https://easydel.readthedocs.io/en/latest/

Latest Updates 🔥

Key Components

GenerationPipeline

The GenerationPipeline class provides a streamlined interface for text generation using pre-trained language models within the JAX framework.

import easydel as ed
from transformers import AutoTokenizer

# Load a causal language model and its parameters; the elided arguments are
# the pretrained model id plus any loading options (dtype, sharding, etc.).
model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

pipeline = ed.GenerationPipeline(model=model, params=params, tokenizer=tokenizer)
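
A minimal usage sketch follows; the generate call and its streamed output are assumptions about the pipeline interface, so verify the exact method signature in the GenerationPipeline reference:

# Hedged sketch: `generate` and its streaming behavior are assumptions and
# may differ between EasyDeL versions.
prompt = "Explain JAX in one sentence."
for chunk in pipeline.generate(prompt):
    print(chunk, end="")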

ApiEngine

ApiEngine is a Serve API Engine for production purposes, providing a stable and efficient API.

import easydel as ed

# Build a chat pipeline and serve it through the API engine.
pipeline = ed.ChatPipeline(...)
engine = ed.ApiEngine(pipeline=pipeline, hostname="0.0.0.0", port=11550)
engine.fire()

EasyDeLState

EasyDeLState acts as a comprehensive container for your EasyDeL model, including training progress, model parameters, and optimizer information.

import jax.numpy as jnp
from easydel import EasyDeLState

state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path="model_name",
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    sharding_axis_dims=(1, -1, 1, 1)
)
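
Once loaded, the state bundles the model parameters with training metadata. The attribute names in the quick inspection sketch below follow the flax TrainState convention that EasyDeLState is modeled on and should be verified against the EasyDeLState reference:

import jax

# Hedged sketch: `step` and `params` are assumed attribute names.
print(state.step)  # training progress carried with the checkpoint
print(jax.tree_util.tree_map(lambda x: x.shape, state.params))  # parameter shapes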

Training Examples

Supervised Fine-Tuning

import flax
from easydel import SFTTrainer, TrainArguments

# `train_arguments` is a TrainArguments instance (a hedged sketch follows this
# example); `prompter` formats raw dataset rows into training prompts.
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
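
The example above assumes train_arguments is a TrainArguments instance. The field names in the sketch below are taken from common EasyDeL examples and may differ between versions, so check them against the TrainArguments reference before use:

# Hedged sketch: field names are illustrative and should be verified against
# the TrainArguments documentation for your EasyDeL version.
train_arguments = TrainArguments(
    model_name="sft-example",        # run / checkpoint name
    num_train_epochs=3,
    learning_rate=2e-5,
    total_batch_size=8,
    max_sequence_length=max_length,  # matches num_of_sequences above
)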

DPO Fine-tuning

from easydel import DPOTrainer

# `state` and `ref_state` are EasyDeLState instances for the policy and the
# frozen reference model; `arguments` carries the shared training arguments.
dpo_trainer = DPOTrainer(
    model_state=state,
    ref_model_state=ref_state,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    arguments=arguments,
    max_length=max_length,
    max_target_length=max_target_length,
    max_prompt_length=max_prompt_length,
)

output = dpo_trainer.train()

Contributing

Contributions to EasyDeL are welcome! Please fork the repository, make your changes, and submit a pull request.

License 📜

EasyDeL is released under the Apache v2 license. See the LICENSE file for more details.

Contact

If you have any questions or comments about EasyDeL, you can reach out to me at erfanzare810@gmail.com.

Citation

To cite EasyDeL in your work:

@misc{ZareChavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}