Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Recipes and high-level entrypoint #1188

Open lantiga opened 1 month ago

lantiga commented 1 month ago

🚀 Feature

Thunder recipes and new high-level entrypoint.

This is important

Motivation

Providing a model to thunder.jit requires an understanding of:

The above is specific to models and cluster configurations, so it would be good to have a way to package everything up in a reusable class that applies to certain models or model families and can be shipped alongside the original code.

For instance, one could have an HFLlama3 recipe, or a more general HFLlama recipe that can be applied to all variants of Llama. A recipe could also expose options for different configurations, such as whether to use distributed execution. Finally, one could have an HFLlama3Hopper recipe that optimizes the combination of executors for a certain architecture.

In a nutshell, the recipe would orchestrate what is needed to make thunder run on a model and what gets applied to that model. Code that uses recipes would not:

One of the uses for a recipe is dealing with implementation details and the way we handle them, e.g.

The possible introduction of a ThunderFX entrypoint makes this even more attractive. The recipe could decide to call into thunder.jit or ThunderFX, according to what the recipe implementor decides to go with.

Pitch

This is what the new entrypoint and a recipe could look like. NOTE: the naming is up for grabs, this is just for demonstration purposes.

Here's a skeleton of a base ThunderRecipe class and the thunder.compile entrypoint

import thunder
from thunder import jit, Transform, Executor

class Lookaside:
    def __init__(self, fn, replace_with):
        self._fn = fn
        self._replace_with = replace_with

class ThunderRecipe:
    def __init__(self):
        pass

    def validate(self, model):
        # subclasses should raise here if the model is not supported
        pass

    # def setup_operators(self) -> list[Operator]:
    #     # this is for registering custom kernels on the fly
    #     return None

    def setup_lookasides(self) -> list[Lookaside] | None:
        return None

    def setup_transforms(self) -> list[Transform] | None:
        return None

    def setup_executors(self) -> list[Executor] | None:
        return None

    def setup_config(self) -> dict:
        # extra keyword arguments forwarded to the jit entrypoint
        return {}

    def setup(self, model):
        self.validate(model)

        lookasides = self.setup_lookasides()

        if lookasides is not None:
            for lookaside in lookasides:
                thunder.jit_ext.register_general_jit_lookaside(lookaside._fn)(
                    thunder.jit_ext.interpreter_needs_wrap()(
                        lookaside._replace_with
                    )
                )

        self.lookasides = lookasides
        self.executors = self.setup_executors()
        self.transforms = self.setup_transforms()
        self._config = self.setup_config()

from collections.abc import Sequence

def compile(model, recipe: None | ThunderRecipe | Sequence[ThunderRecipe] = None):
    if recipe is None:
        recipe = ThunderRecipe()

    recipes = list(recipe) if isinstance(recipe, Sequence) else [recipe]

    transforms = []
    executors = []
    config = {}

    for r in recipes:
        r.setup(model)
        transforms.extend(r.transforms or [])
        executors.extend(r.executors or [])
        config.update(r._config)

    jmodel = jit(model,
                 transforms=transforms,
                 executors=executors,
                 **config)

    return jmodel

Example recipe for HFBert

import torch
import transformers

class CompileHFBert(ThunderRecipe):
    def __init__(self):
        super().__init__()

    def validate(self, model):
        if not isinstance(model, transformers.BertForSequenceClassification):
            raise ValueError("The model must be a BertForSequenceClassification")

    def setup_lookasides(self):
        warn_lookaside = Lookaside(
            fn=transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask,
            replace_with=lambda *args: None
        )

        if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
            is_compiling = torch.compiler.is_compiling
        else:
            is_compiling = torch._dynamo.is_compiling

        is_compiling_lookaside = Lookaside(
            fn=is_compiling,
            replace_with=lambda *args: True
        )

        return [warn_lookaside, is_compiling_lookaside]

    def setup_transforms(self):
        return None

    def setup_executors(self):
        return None

bert = transformers.BertForSequenceClassification(transformers.BertConfig())

# the default should work
t_bert = thunder.compile(bert)

# this should work as well
t_bert = thunder.compile(bert, recipe=CompileHFBert())

One could think about composing recipes. A basic quantization recipe:

class Quantize4Bit(ThunderRecipe):
    def __init__(self):
        super().__init__()

    def setup_transforms(self):
        from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit
        return [BitsAndBytesLinearQuant4bit()]

    def setup_executors(self):
        from thunder.transforms.quantization import get_bitsandbytes_executor
        return [get_bitsandbytes_executor()]

And a composed recipe with configurable quantization:

class QuantizedHFBert(ThunderRecipe):
    def __init__(self, quantize=False):
        super().__init__()
        self.bert_recipe = CompileHFBert()
        if quantize:
            self.quant_recipe = Quantize4Bit()
        self.quantize = quantize

    def setup_lookasides(self):
        return self.bert_recipe.setup_lookasides()

    def setup_transforms(self):
        if self.quantize:
            return self.quant_recipe.setup_transforms()
        return None

    def setup_executors(self):
        if self.quantize:
            return self.quant_recipe.setup_executors()
        return None

t_bert = thunder.compile(bert, recipe=QuantizedHFBert(quantize=True))

One could also think about specifying lists of recipes, but I'm on the fence about it, at least initially. We could have rich rules on how to compose, but composing manually like above is probably better while we're getting a sense for the system.

t_bert = thunder.compile(bert, recipe=[CompileHFBert(), Quantize4Bit()])

t-vi commented 1 month ago

I like the proposal in general, a couple of details:

apaz-cli commented 1 month ago

Overall, I like it. I do think that we need to be able to bundle together all the nonsense we do to a model. Although I think that ideally recipes should be combinable and composable? Or at least that was the original goal. Quantizing and then distributing vs distributing and then quantizing should be equivalent, and likewise for the grad transform, etc. If that makes semantic sense. You tell me, I suppose. On the surface it seems like we're trading the issue of non-composable transforms for the issue of non-composable recipes.

I like exposing Lookaside as a class. I think that de-mystifies things quite a bit. It makes it discoverable in the docs, and passing them along to thunder.jit() makes a lot of sense. I think it's a core functionality of Thunder, and I don't think it makes sense that it's so buried. Also, like we talked about, I would prefer to move stuff out of global context variables. I think this is a good change regardless.

I really like setup_operators().

What does setup_config() do? I understand all the other methods.

I take it that setup() is not supposed to be overridden? I forget if there's an annotation for that. If not, it might be good to make that explicit. I suppose if somebody wants to do arbitrary things when creating a recipe, they can subclass and put it in the __init__ function? And then when they apply it, they can put arbitrary stuff in setup_operators()?
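For reference, one existing way to make "do not override" explicit is typing.final, which type checkers (not the runtime) enforce; a minimal sketch:

from typing import final

class ThunderRecipe:
    @final  # type checkers will flag subclasses that override setup()
    def setup(self, model):
        ...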

As long as it actually improves (or at least doesn't worsen) the composability problem, looks dope. I think we've had the problem of the thunder API being scattered across a lot of different methods and packages for a long time now, and I like that this centralizes it. It will make it a lot easier to understand for newcomers.

lantiga commented 1 month ago

Thanks for the comments!

Regarding composable recipes: I have thoughts about adding some sort of traits to recipes (or to transforms themselves) that would inform how to compose things together (as in: "I need to come after A and B kinds of things, and before E and F kinds of things"); a rough sketch follows after the list below.

My ideal sequence would be:

  1. put the burden on the recipe developer to do what I did above
  2. support lists of recipes to be applied in the same order
  3. traits

As far as I'm concerned we can stay with 1 for as long as we need to really figure out how recipes will compose, but we know we are not cornered there.
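To make the traits idea (option 3) a bit more concrete, here's a minimal sketch of what declaring ordering constraints on a recipe could look like. RecipeTraits and the attribute names are hypothetical, nothing like this exists yet:

class RecipeTraits:
    # hypothetical: ordering constraints a recipe declares for automatic composition
    def __init__(self, after=(), before=()):
        self.after = set(after)    # recipe names this one must follow
        self.before = set(before)  # recipe names this one must precede

class SomeQuantRecipe(ThunderRecipe):
    # compile() could topologically sort recipes based on these traits
    traits = RecipeTraits(after=("CompileHFBert",), before=("SomeFSDPRecipe",))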

Maybe, to make the developer experience nicer, we could also add a way to quickly add an extra transform or executor to a recipe inline, without subclassing.
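For illustration only, such an inline API could look roughly like this; add_transforms and add_executors are hypothetical method names, not part of the proposal above:

from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit, get_bitsandbytes_executor

recipe = CompileHFBert()

# hypothetical convenience methods: append extras without subclassing the recipe
recipe.add_transforms(BitsAndBytesLinearQuant4bit())
recipe.add_executors(get_bitsandbytes_executor())

t_bert = thunder.compile(bert, recipe=recipe)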

setup_config is for those flags one wants to pass to thunder.jit (or ThunderFX, depending on the recipe). Not sure it's the best way to do that, we'll see.
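As a minimal sketch of what that could look like (the recipe name is made up, and disable_torch_autograd is assumed here to be a valid thunder.jit keyword argument):

class CompileHFBertInference(CompileHFBert):
    def setup_config(self):
        # the returned dict is forwarded as **kwargs to thunder.jit by compile()
        return {"disable_torch_autograd": True}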

Correct, setup() is not to be overridden, unless you really want to. Maybe we can just rename it to _setup() or something that makes it clearer that you shouldn't mess with it.

crcrpar commented 1 month ago

Overall sounds great to me. I have some questions and comments to help myself understand this proposal better.

Q1 -- setup_operators: Would it even let us register a custom executor, like the apex executor we already have?

Q2 -- setup_<foo>: Do they really need to be member methods, instead of staticmethods?

lantiga commented 1 month ago

Thank you @crcrpar

Q1: yes, in the case where we want to offer a straightforward way to add a single-operator executor. For anything more complex, we should require defining a proper executor.
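As a rough sketch of what the single-operator case could look like inside setup_operators, loosely modeled on thunder's existing OperatorExecutor API; the exact names and signatures here are assumptions, not a confirmed interface:

import torch
import thunder
from thunder.extend import OperatorExecutor, register_executor

class FusedGELURecipe(ThunderRecipe):
    def setup_operators(self):
        ex = OperatorExecutor("fused_gelu_ex", version="0.1")
        register_executor(ex)

        # meta function describing the output proxy for the custom op
        def gelu_like(a):
            return thunder.TensorProxy(like=a)

        def gelu_impl(a):
            return torch.nn.functional.gelu(a, approximate="tanh")

        # register a single operator backed by a custom implementation
        ex.register_operator("fused_gelu", like=gelu_like, fn=gelu_impl)
        return [ex]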

Q2: the setup methods could rely on properties that we set on the recipe. For example, I could have a recipe that takes a use_fsdp=True|False argument in the constructor, and you would want to access self.use_fsdp in setup_executors. So I think they need to be member methods.
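A minimal sketch of that pattern, with MyFSDPTransform as a placeholder rather than a real thunder transform:

class DistributedRecipe(ThunderRecipe):
    def __init__(self, use_fsdp: bool = False):
        super().__init__()
        self.use_fsdp = use_fsdp

    def setup_transforms(self):
        # reads constructor state, which is why these hooks are member methods
        if self.use_fsdp:
            return [MyFSDPTransform()]  # placeholder transform
        return None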