pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Expensive PyTorch to StableHLO conversion #8331

Open Wheest opened 1 month ago

Wheest commented 1 month ago

I've been using the StableHLO conversion, and have found that it is much more expensive than I expected.

Relatively "small" LLMs take a few hours, and a larger model, meta-llama/Llama-3.1-405B, ran for multiple days at 100% CPU across 15 threads without the conversion completing. I'm unsure how long I would have to leave it for it to work, whether throwing more CPU at it would help, or whether I've hit an edge case that causes it to hang.

In this case, here is the conversion script I've been using:

import torch
import torch.export._trace
from transformers import AutoTokenizer, LlamaForCausalLM
from torch_xla.stablehlo import exported_program_to_stablehlo, StableHLOExportOptions
import argparse
import logging

# Set up logging to file and console
def setup_logging():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        handlers=[logging.FileHandler("export_log.log"), logging.StreamHandler()],
    )

def main(model_name):
    logging.info(f"Starting model export for {model_name}")

    # Load tokenizer and model
    cache_dir = "/models/"
    name = model_name.split("/")[-1]

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = LlamaForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

    base_prompt = "How many hours are in a day?"
    base_inputs = tokenizer(base_prompt, return_tensors="pt")
    input_ids = base_inputs.input_ids

    logging.info("Exporting model using torch export...")
    exported = torch.export._trace._export(
        model,
        (input_ids,),
        strict=False,
        _allow_complex_guards_as_runtime_asserts=True,
    )

    del model  # Free memory

    logging.info("Exporting model to StableHLO...")
    stablehlo_program = exported_program_to_stablehlo(
        exported,
        options=StableHLOExportOptions(save_weights=False, export_weights=False),
    )

    output_file = f"examples/stablehlo_{name}.mlir"
    with open(output_file, "w") as f:
        f.write(stablehlo_program.get_stablehlo_text("forward"))

    logging.info(f"StableHLO exported successfully to {output_file}")

if __name__ == "__main__":
    setup_logging()
    parser = argparse.ArgumentParser(
        description="Generate StableHLO from Hugging Face model"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Meta-Llama-3-70B",
        choices=[
            "meta-llama/Meta-Llama-3-70B",
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/Llama-3.1-405B",
        ],
        help="The model to generate StableHLO from",
    )
    args = parser.parse_args()
    main(args.model)

Note that I'm using StableHLOExportOptions to avoid saving the weights, since right now I only care about the model structure in StableHLO. Even with that, I need over 1.5 TB of memory for Llama-3.1-405B, and I'm using SSD swap space to augment my main memory.
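
One thing I plan to try, in case it helps narrow things down, is loading the checkpoint in bf16 instead of letting transformers upcast to fp32 (which I believe is the default), since that should roughly halve host memory. An untested sketch of the change to the from_pretrained call above:

model = LlamaForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    torch_dtype=torch.bfloat16,  # keep the checkpoint's bf16 dtype rather than upcasting to fp32
    low_cpu_mem_usage=True,      # avoid materializing a full randomly-initialized copy before loading
)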

Any ideas on how I can make the conversion more efficient in this case? Are there serious bottlenecks in my setup that could be improved? I know that swap is slower than regular memory, but if I'm not converting the weights, surely that matters less?
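
For completeness, this is the sort of thing I mean by not needing the weights: building the model from its config on the meta device so that no parameter data is allocated at all. This is an untested sketch, and I don't know whether torch.export or the StableHLO lowering will accept meta-initialized parameters:

import torch
from transformers import AutoConfig, LlamaForCausalLM

# Construct only the model skeleton; parameters live on the "meta" device
# and carry no data, so host memory stays small.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-405B")
with torch.device("meta"):
    model = LlamaForCausalLM(config)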