pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Make compiled models serializable #101107

Open mariosasko opened 1 year ago

mariosasko commented 1 year ago

🐛 Describe the bug

Serializing a compiled model with pickle fails with "Can't pickle local object 'convert_frame.<locals>._convert_frame'". With dill, it fails with "cannot pickle 'ConfigModuleInstance' object".

A Colab with an example: https://colab.research.google.com/drive/1v6jUUq86ql1Era4X47cIDj7bzrrz2RZe?usp=sharing

In Hugging Face Datasets, this error stops us from generating (deterministic) hashes for transforms (functions) that reference a compiled model, meaning such transforms cannot be cached and must be re-computed each time when transforming a dataset.

(The "export" API for the compiled models would also work for us.)

Error logs

No response

Minified repro

No response

Versions

Colab env with torch 2.0.1 installed

```
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.10.11 (main, Apr 5 2023, 14:15:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.147+-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 2
On-line CPU(s) list: 0,1
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping: 0
CPU MHz: 2200.196
BogoMIPS: 4400.39
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB
L1i cache: 32 KiB
L2 cache: 256 KiB
L3 cache: 55 MiB
NUMA node0 CPU(s): 0,1
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect
```

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @soumith @wconstab @ngimel

msaroufim commented 1 year ago

This has come up a few times, so I'm aggregating how this works today and our plans for next steps.

  1. Indeed, you can't pickle an optimized module, but you can pickle the original module because the weights are shared (see https://pytorch.org/get-started/pytorch-2.0/#serialization). I'm planning to add simple get/set state methods here that unwrap the original module automatically.
  2. Because 1 is annoying, we recently introduced an in-place module compilation API that makes saving and loading work: https://github.com/pytorch/pytorch/pull/97565
  3. To improve reproducibility, we're also thinking of saving all the config arguments that were passed to torch.compile and persisting them when you save and load a model.
  4. Unfortunately, 2 doesn't solve the problem of having to recompile a model when you load it, so cold starts for inference are bad. I'll have a POC working to solve this very soon. The core idea is to dump the entire inductor cache (which includes a triton cache, an inductor cache, and soon an fx cache) into a state dict and reload it later: https://github.com/msaroufim/mlsys-experiments/blob/main/compile-checkpoint/save-hook.py
  5. Conceptually, 4 should work if you assume that the same machine type will be used for inference and training. However, torch.load and torch.save have a contract in that they guarantee working across devices, and this might not hold for us. So maybe we instead just write some docs recommending that users copy their caches to a networked file system?
  6. Export is obviously a good solution, but so far there is no date for an official release. Export is also focused on environments without Python available; if Python is available, I think 3 will work just fine. EDIT: export is now available.
  7. It might be possible to pickle/dill the entire compiled module, but I haven't figured out how yet since there's lots of dynamic behavior. dill at least is powerful enough to pickle a Python interpreter, so I feel like it should work. One thing we can do is automatically unwrap and pickle the unoptimized module when someone tries to pickle the optimized one. EDIT: I got stuck working on this because of extra graph breaks, but I'd be happy to help merge if someone wants to pick this up: https://github.com/pytorch/pytorch/pull/101651
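
Point 1, and the automatic unwrapping floated in point 7, can be sketched roughly as follows. This is a minimal illustration and not an official API: it relies on the _orig_mod attribute that the compiled module wrapper exposes, and pickle_unwrapped is a hypothetical helper name. The compiled artifacts are not saved, so the restored module has to be compiled again.

```python
import pickle

import torch
import torch.nn as nn


def pickle_unwrapped(model: nn.Module) -> bytes:
    """Pickle the original module behind a compiled wrapper (hypothetical helper)."""
    # torch.compile(module) returns a wrapper that keeps the original module in
    # `_orig_mod`; fall back to the model itself if it was never compiled.
    original = getattr(model, "_orig_mod", model)
    return pickle.dumps(original)


model = nn.Linear(4, 2)
compiled = torch.compile(model)  # lazy: nothing is compiled until the first call

payload = pickle_unwrapped(compiled)
restored = pickle.loads(payload)      # a plain nn.Linear with the same weights
recompiled = torch.compile(restored)  # compilation has to be redone after loading
```
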
mariosasko commented 1 year ago

Thanks for the quick response! I did not mention this in the issue description, but the ability to pickle the compiled functions would also be great (we use dill, which can also pickle functions). Considering it's already possible to fetch the original function/method of a compiled function/model, the simplest solution that would work for us is exposing the params passed to torch.compile (e.g., as an attribute of the compile context). Then, we could define a simple reduction function to make the pickling possible.

(If I'm not mistaken, only some params can currently be fetched or inferred, e.g., disable.)
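
Until those params are exposed, one way to approximate this is to track them yourself and pickle them next to the original module. The sketch below is only an illustration under that assumption: dumps_compiled and loads_compiled are hypothetical helpers, and the kwargs dict is whatever the caller originally passed to torch.compile, not something read back from the compiled object.

```python
import pickle

import torch
import torch.nn as nn


def dumps_compiled(compiled: nn.Module, compile_kwargs: dict) -> bytes:
    """Serialize the original module together with the torch.compile kwargs."""
    # `_orig_mod` holds the uncompiled module. The kwargs must be supplied by the
    # caller because the compiled wrapper does not expose all of them.
    original = getattr(compiled, "_orig_mod", compiled)
    return pickle.dumps({"module": original, "compile_kwargs": dict(compile_kwargs)})


def loads_compiled(data: bytes) -> nn.Module:
    """Rebuild the compiled module by re-applying torch.compile with the saved kwargs."""
    payload = pickle.loads(data)
    return torch.compile(payload["module"], **payload["compile_kwargs"])


compile_kwargs = {"backend": "eager", "dynamic": False}  # example values only
compiled = torch.compile(nn.Linear(4, 2), **compile_kwargs)

data = dumps_compiled(compiled, compile_kwargs)
restored = loads_compiled(data)  # compiled again lazily, with the same arguments
```
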

msaroufim commented 1 year ago

@mariosasko just wanna make sure I understand: as a first step, it sounds like you're mostly interested in knowing exactly which args were used when compiling a model, for reproducibility. If so, I was also planning on just putting that in the nn module state dict.

I would really love to just be able to dill or pickle an entire optimized module, but there are way too many setattr calls to make that easy. I'll still dig through it to see what's possible.

varunshenoy commented 1 year ago

Any updates on this, or has anyone found a workaround?

cc: @msaroufim

msaroufim commented 1 year ago

So the simplest workaround is to save the state dict rather than the model, which we mentioned back when 2.0 was released: https://pytorch.org/get-started/pytorch-2.0/#serialization

I tried to get saving the model to work directly (https://github.com/pytorch/pytorch/pull/101651), and it did work: you could effectively save compiled models directly. The problem was that my changes introduced some extra graph breaks across the board, which have a performance impact. I couldn't figure it out and I don't have bandwidth to investigate further, but if someone would like to revisit this I'd be happy to review and merge.

YodaEmbedding commented 11 months ago

My workaround is to "repair" checkpoints that contain the undesired "_orig_mod." prefix.

Save the following script as checkpoint_unwrap_orig_model.py:

import sys
import torch

def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text

def repair_checkpoint(path):
    ckpt = torch.load(path)
    in_state_dict = ckpt["model_state_dict"]
    pairings = [
        (src_key, remove_prefix(src_key, "_orig_mod."))
        for src_key in in_state_dict.keys()
    ]
    if all(src_key == dest_key for src_key, dest_key in pairings):
        return  # Do not write checkpoint if no need to repair!
    out_state_dict = {}
    for src_key, dest_key in pairings:
        print(f"{src_key}  ==>  {dest_key}")
        out_state_dict[dest_key] = in_state_dict[src_key]
    ckpt["model_state_dict"] = out_state_dict
    torch.save(ckpt, path)

if __name__ == "__main__":
    paths = sys.argv[1:]
    for path in paths:
        print(path)
        repair_checkpoint(path)
        print("========")

Then:

python checkpoint_unwrap_orig_model.py **/*.pth

NOTE: In my checkpoints, the state_dict is actually inside ckpt["model_state_dict"]. If yours is in a different place, adjust that as necessary, e.g. ckpt if your state_dict is exactly the root of the checkpoint.
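
If rewriting checkpoint files on disk is undesirable, the same repair can be applied in memory right before load_state_dict(). A minimal sketch, assuming Python 3.9+ for str.removeprefix; strip_orig_mod_prefix is a hypothetical helper, and plain integers stand in for tensors so the snippet runs without PyTorch:

```python
from collections import OrderedDict

PREFIX = "_orig_mod."


def strip_orig_mod_prefix(state_dict):
    """Return a copy of state_dict with a leading "_orig_mod." removed from each key."""
    # Only a leading prefix is removed; keys without it pass through unchanged.
    return OrderedDict(
        (key.removeprefix(PREFIX), value) for key, value in state_dict.items()
    )


# Integers stand in for tensors here; real state dicts map names to tensors.
saved = OrderedDict([("_orig_mod.fc.weight", 1), ("_orig_mod.fc.bias", 2)])
clean = strip_orig_mod_prefix(saved)
print(list(clean))  # ['fc.weight', 'fc.bias']
```

With a real checkpoint laid out as above, this would be used as model.load_state_dict(strip_orig_mod_prefix(ckpt["model_state_dict"])).
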

wilson97 commented 11 months ago

@msaroufim I tried your workaround (torch.compile my model, then save the state_dict, then load a new non-compiled version of my model, finally insert the saved state_dict) and the model is slow (non-compiled). Am I doing something wrong? Or are you saying that I have to recompile regardless of whether I load a compiled state_dict or not?

pallgeuer commented 9 months ago

If possible (i.e. because you can change the code that is generating the checkpoints), you can save model._orig_mod.state_dict() instead of model.state_dict() for compiled models. This will avoid the _orig_mod. prefix everywhere, and therefore not require any fixing on load. You just need to load_state_dict() and torch.compile() after loading.

Generic code that doesn't know whether a model is compiled or not could do something like: getattr(model, '_orig_mod', model).state_dict()
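
Put together, a save/load round trip might look like the sketch below. It is a minimal illustration on a toy nn.Linear, with an in-memory buffer standing in for a checkpoint file. The state dict only carries weights, so the freshly loaded model runs uncompiled until torch.compile() is applied again.

```python
import io

import torch
import torch.nn as nn

model = torch.compile(nn.Linear(4, 2))  # compiled wrapper around the real module

# Save the unwrapped module's weights so the keys carry no "_orig_mod." prefix.
# The getattr() fallback makes this work for uncompiled models too.
state_dict = getattr(model, "_orig_mod", model).state_dict()
buffer = io.BytesIO()  # stands in for a checkpoint file
torch.save(state_dict, buffer)

# Later: rebuild the architecture, load the weights, then compile again.
buffer.seek(0)
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(buffer))
fresh = torch.compile(fresh)  # compiled artifacts are not stored in the state dict
```
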

fxmarty commented 7 months ago

This is a duplicate of https://github.com/pytorch/pytorch/issues/93470

ecstayalive commented 7 months ago

I usually override the state_dict() and load_state_dict() methods, using a specific structure that solves most of these problems. For example:

from collections import OrderedDict
import torch.nn as nn

class CustomNet(nn.Module):
    def __init__(self, kernel: nn.Module):
        super().__init__()  # must run before submodules are assigned
        self.kernel = kernel  # the wrapped network that holds the weights

    def forward(self, *args, **kwargs):
        return self.kernel(*args, **kwargs)

    # Override state_dict(): drop whatever prefix the caller passes in
    # (such as "_orig_mod.") and emit only the kernel's own keys.
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        self.kernel.state_dict(destination=destination, prefix="", keep_vars=keep_vars)
        return destination

    # Override load_state_dict(): load the unprefixed keys straight into the kernel.
    def load_state_dict(self, state_dict, *args, **kwargs):
        return self.kernel.load_state_dict(state_dict, *args, **kwargs)

This structure removes the '_orig_mod' prefix. The code above only illustrates the method, so adapt the structure to your own code.

tugsbayasgalan commented 6 months ago

@msaroufim any updates?

msaroufim commented 6 months ago

Nope, unassigning myself for now since I haven't had time to keep fixing issues here.

fxmarty commented 6 months ago

Isn't torch.export.export kind of similar to what is requested here? https://pytorch.org/docs/stable/export.html

rfeinman commented 4 months ago

@fxmarty torch.export.export only compiles a single path from the program control flow, as far as I understand.

It would be great to be able to save a model that has run torch.compile so that we do not need to re-compile each time we launch a program! +1 for this

angelayi commented 4 months ago

@rfeinman Could you clarify what you mean by "torch.export.export only compiles a single path from the program control flow"? Export should be able to handle control flow if it is rewritten using torch.cond.

rfeinman commented 4 months ago

@angelayi my understanding is based off of the explanation here: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing

Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).

TorchScript tracing f1 results in silently incorrect results, since only the actual control flow path is traced.

I think the ability of torch.compile to handle arbitrary Python code with minimal changes is a very nice feature, and it would be great if this feature could transfer to serialization (i.e., if we don't have to swap in torch.cond, etc).
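
The tracing pitfall quoted above can be reproduced with a short sketch. The function follows the f1 example the quote refers to; the input values are made up for illustration. torch.jit.trace records only the branch taken for the example inputs (and emits a TracerWarning about it), so the traced function returns the wrong result when the other branch should run.

```python
import torch


def f1(x, y):
    # Data-dependent control flow: the branch depends on the values in x.
    if x.sum() < 0:
        return -y
    return y


positive = torch.ones(3)
negative = -torch.ones(3)
y = torch.arange(3.0)

# Tracing with a positive x bakes in the "return y" branch.
traced_f1 = torch.jit.trace(f1, (positive, y))

eager_result = f1(negative, y)          # takes the "-y" branch
traced_result = traced_f1(negative, y)  # still returns y: silently incorrect
```
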

angelayi commented 4 months ago

@rfeinman that makes sense! torch.export wants to get a full graph representation of the code so it requires these code rewrites, instead of defaulting to the python code, which is what torch.compile does.

fxmarty commented 4 months ago

Maybe a single GraphModule is able to represent torch.cond then? I was not aware of that

edit: yep:

import torch

def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()

def false_fn(x: torch.Tensor):
    return x.sin()

class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

dyn_shape_mod = DynamicShapeCondPredicate()

res = dyn_shape_mod(torch.randn(10, 10))

##

from torch.export import export

example_args = (torch.randn(10, 10),)

exported_program = export(
    DynamicShapeCondPredicate(), args=example_args
)
print(exported_program)
print(exported_program.graph)

The last print outputs:

graph():
    %l_x_ : [num_users=1] = placeholder[target=l_x_]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (True, %true_graph_0, %false_graph_0, [%l_x_]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%conditional, 0), kwargs = {})
    return (getitem,)

now when it comes to which hardware/compiler is able to consume that...