Open monorimet opened 4 months ago
There is a mechanism that already exists for this, and it's plumbed through everything from compiler to runtime. In fact, the Python API already uses it in some situations to provide enhanced type casting, etc., if a JSON dict with more information is present on the function.
It might take some archaeology to figure out how to populate it in the newer torch path: it was created and used for TF, which had a lot of "sidecar" data needed for interop.
Let me see if I can find a starting point.
Here's an example where we are mocking such a reflection dict: https://github.com/iree-org/iree/blob/main/runtime/bindings/python/tests/function_test.py
And here is where it is obtained in the python API: https://github.com/iree-org/iree/blob/7b58c712a1c6bc1a13fc4525ef07b0030a950d86/runtime/bindings/python/iree/runtime/function.py#L150
I can't find an in-tree example that sets the attribute on the compiler side to populate something like that. Most of that was old TF interop code that is no longer among the living, but the mechanism still works.
Just a sec.
Since this is basically an unused code path now, just noting some tests that verify it:
Actually, that's the best I can find on my phone. A git grep for iree.reflection will show more. That's the function attribute that is preserved through to the runtime.
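As a rough illustration of the runtime side, here is a minimal sketch that decodes per-function metadata in the spirit of the JSON dict handling in the linked function.py. It assumes only that each function exposes its reflection data as a string-to-string mapping. MockFunction and the "model_config" key are hypothetical stand-ins so the snippet runs without IREE installed.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class MockFunction:
    """Hypothetical stand-in for a runtime function handle with reflection data."""
    name: str
    # Reflection entries are modeled as plain string -> string pairs (assumption).
    reflection: Dict[str, str] = field(default_factory=dict)


def read_metadata(fn: MockFunction, key: str = "model_config") -> Optional[Dict[str, Any]]:
    """Return the JSON-decoded metadata stored under `key`, or None if absent."""
    raw = fn.reflection.get(key)
    if raw is None:
        # No reflection data: the caller falls back to plain dynamic invocation.
        return None
    return json.loads(raw)


fn = MockFunction(
    name="run_forward",
    reflection={"model_config": json.dumps({"lora": False, "dtype": "f16"})},
)
print(read_metadata(fn))  # {'lora': False, 'dtype': 'f16'}
```

Functions without the key simply return None, so metadata stays optional and older modules keep working.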
Motivation
In turbine-models we have single export functions for different Hugging Face models. Each of these exports (and its respective CompiledModule instance) is parametrized by things like I/O shapes, inlined vs. external weights, data types, and potentially features like embeddings that change the IR for a specific export.
Currently, most of the "config control" we do for these exported artifacts relies on careful filename conventions. This works for now, but it puts all responsibility for validation on file paths, which are constrained (for example, in length) on platforms like Windows.
If we want to keep using AOT-compiled modules as we do for Stable Diffusion, we will need a more reliable way to ensure configuration control. Additionally, with a good enough solution for metadata assignment, we gain the ability to infer inference methods from an ordered collection of compiled modules.
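To contrast with filename conventions, here is a minimal sketch of carrying the same configuration as structured data instead. The fields of the hypothetical ExportConfig class mirror the parameters listed above (I/O shape, data type, external weights, LoRA), and the whole config is encoded as one JSON string on the assumption that reflection values are strings. The "model_config" key is likewise hypothetical.

```python
import json
from dataclasses import asdict, dataclass
from typing import Dict


@dataclass(frozen=True)
class ExportConfig:
    """Hypothetical set of knobs that change the IR of one exported function."""
    batch_size: int
    height: int
    width: int
    dtype: str = "f16"
    external_weights: bool = True
    lora: bool = False

    def to_reflection(self) -> Dict[str, str]:
        # Encode the whole config as a single JSON string value.
        # sort_keys makes the encoding deterministic for equal configs.
        return {"model_config": json.dumps(asdict(self), sort_keys=True)}

    @classmethod
    def from_reflection(cls, reflection: Dict[str, str]) -> "ExportConfig":
        # Rebuild the config from the JSON stored under the hypothetical key.
        return cls(**json.loads(reflection["model_config"]))


cfg = ExportConfig(batch_size=1, height=1024, width=1024, lora=True)
attrs = cfg.to_reflection()
print(attrs["model_config"])
# {"batch_size": 1, "dtype": "f16", "external_weights": true, "height": 1024, "lora": true, "width": 1024}
```

Unlike a filename, the decoded value can be validated field by field, and adding a new knob does not lengthen any path.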
Use cases
A potential solution to this problem can be expressed through the following example:
Given an exported UNetCondition2d model, and assuming we are using externalized weights, we could provide, in one MLIR module, multiple methods for configurations that change the IR but that users usually expect to toggle seamlessly.
For example, if a user wants to run the "default" txt2img inference for Stable Diffusion at a specific output shape, we export a vmfb with a single public function (run_forward or similar) that mirrors the forward call of the torch.nn.Module.
Now, what if they want to use a LoRA embedding? Export another instance of the module, add "lora" to the filename this time, unload the old .vmfb, load the new one, and then they're good to go. In the current state, this process can take anywhere from 10 seconds for small models to several minutes, depending on the state of compilation for a given backend, load time for large parameter indexes, etc. In a perfect world we could do all of this JIT and the whole process would happen in a flash, but I think we can do better by prefetching a few different inference modes ahead of time and letting the user-facing application call whichever public function matches its input configuration.
If we embed metadata into each function as a reflection attribute, we can scan each public function of an exported module for a matching configuration without remapping parameters. Then the flow would look like this from the turbine API point of view:
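One way that scan could work, sketched under stated assumptions: each public function carries its configuration as JSON under a hypothetical "model_config" reflection key, and the caller supplies a mapping from function name to reflection dict rather than reading it from a real module. find_matching_function is a hypothetical helper, not an existing turbine API.

```python
import json
from typing import Any, Dict, Mapping, Optional


def find_matching_function(
    functions: Mapping[str, Mapping[str, str]],
    requested: Dict[str, Any],
    key: str = "model_config",
) -> Optional[str]:
    """Return the name of the first function whose metadata contains `requested`.

    A function matches when every requested key/value pair appears in its
    decoded metadata; functions without metadata are skipped.
    """
    for name, reflection in functions.items():
        raw = reflection.get(key)
        if raw is None:
            continue
        meta = json.loads(raw)
        # Require the key to be present so a requested None never matches a missing field.
        if all(k in meta and meta[k] == v for k, v in requested.items()):
            return name
    return None


# Hypothetical module exporting two inference modes of the same model.
module_functions = {
    "run_forward": {"model_config": json.dumps({"height": 1024, "lora": False})},
    "run_forward_lora": {"model_config": json.dumps({"height": 1024, "lora": True})},
}
print(find_matching_function(module_functions, {"lora": True}))  # run_forward_lora
```

Callers only specify the knobs they care about. Because Python dicts preserve insertion order, the first match in export order wins; looping over an ordered list of modules the same way would give the "ordered collection of compiled modules" lookup mentioned in the motivation.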
Implementation
The concept of exporting multiple inference modes in a single CompiledModule doesn't need extra work. Including metadata in a reflection attribute simply allows this to be implemented in a more reliable, abstracted form that can be managed in base classes. We already include torch argument schemas as function attributes through CompiledModule exports. I wonder whether these schemas were designed for a similar purpose, or whether they could be included in the reflection metadata along with more information from the CompiledModule instance.
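If the argument schemas were folded into the same reflection entry, one option is a single versioned payload, sketched below. The schema is treated as an opaque string (whatever serialization the export path already produces), and the payload layout, the "model_config" key, and METADATA_VERSION are all hypothetical. The version check lets the runtime reject metadata written in a format it does not understand.

```python
import json
from typing import Any, Dict

METADATA_VERSION = 1  # Hypothetical format version, bumped on breaking changes.


def build_reflection(config: Dict[str, Any], args_schema: str) -> Dict[str, str]:
    """Bundle a config and an opaque argument schema into one reflection entry."""
    payload = {"version": METADATA_VERSION, "config": config, "args_schema": args_schema}
    return {"model_config": json.dumps(payload, sort_keys=True)}


def parse_reflection(reflection: Dict[str, str]) -> Dict[str, Any]:
    """Decode the payload, refusing versions this reader does not know."""
    payload = json.loads(reflection["model_config"])
    if payload.get("version") != METADATA_VERSION:
        raise ValueError(f"unsupported metadata version: {payload.get('version')!r}")
    return payload


refl = build_reflection({"lora": True}, args_schema="<serialized schema>")
print(parse_reflection(refl)["config"])  # {'lora': True}
```

Keeping the schema opaque means this layer does not need to understand the torch-side format, only carry it alongside the configuration.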