TorchMoE / MoE-Infinity

PyTorch library for cost-effective, fast and easy serving of MoE models.
Apache License 2.0

MoE-Infinity API Proposal #2

Closed drunkcoding closed 5 months ago

drunkcoding commented 7 months ago

Description

We propose a class MoE as the entry point. It loads a (potentially sharded) checkpoint into a model, sends weights to a given device as they are loaded, and adds the hooks that make the model run properly (even if it is split across devices).

The class also provides a generate member function that overrides the default generate and adds tracing capability. It behaves the same as HuggingFace model.generate.

import os
from typing import Any, Union

import torch


class MoE:
  def __init__(self, model_name_or_path: Union[str, os.PathLike], config: Union[dict, os.PathLike] = None) -> None:
    """
    Args:
        model_name_or_path (`str` or `os.PathLike`): The model to load. It can be:
            - a name of HuggingFace Transformers model
            - a path to a file containing a whole model state dict
            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
        config (`Dict` or `os.PathLike`): The MoE-Infinity configuration. It can be:
            - a Python dictionary containing the configuration
            - a path to a JSON file containing the configuration
    """
    pass

  def generate(self, input_ids: torch.LongTensor, **kwargs) -> Any:
    """  
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation. If `past` is used, only `bos_token_id` is used as
            prompt.
        **kwargs: Additional arguments for the generation method. Check the HuggingFace documentation of the model's
            `generate` method for the supported arguments.

    Returns:
        `torch.LongTensor` of shape `(batch_size, sequence_length)`:
            The generated sequences. Sequences shorter than `min_length` are padded with `pad_token_id`.
    """
    pass
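
To make the accepted argument forms concrete, here is a minimal sketch of the constructor called with each kind of input listed in the docstring; the local paths below are placeholders, not part of the proposal.

from moe_infinity import MoE

# HuggingFace model name plus an inline configuration dictionary
model_a = MoE("mistralai/Mixtral-8x7B-Instruct-v0.1", {"offload_path": "/tmp/moe-offload"})

# a local sharded checkpoint folder (containing a unique `.index.json` and the shards),
# with the configuration given as a path to a JSON file -- both paths are hypothetical
model_b = MoE("/models/mixtral-8x7b-instruct", "/configs/moe_infinity.json")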

Usage examples

import torch
import os
from transformers import AutoTokenizer
from moe_infinity import MoE

user_home = os.path.expanduser('~')

checkpoint = 'mistralai/Mixtral-8x7B-Instruct-v0.1'

# specifies the path on disk to offload parameters
config = {
    "offload_path": os.path.join(user_home, "moe-infinity"),
}

model = MoE(checkpoint, config) # one line change to support offloading

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output_ids = model.generate(input_ids)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(output_text)
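
Since generate forwards **kwargs to the underlying HuggingFace generate, the usual generation arguments should pass through unchanged. A hedged sketch (the argument values are illustrative, not proposed defaults):

output_ids = model.generate(
    input_ids,
    max_new_tokens=64,  # standard HuggingFace generation argument
    do_sample=False,    # greedy decoding
)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
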
luomai commented 7 months ago

Looks quite good to me. Thanks.

For the class name, I am wondering if we could name it LLM? The current name implies we only support MoE.

Is there a use case for our project to be used for GPT-like models (simply using our better system implementation for offloading)?