Open mariosasko opened 1 year ago
This has come up a few times, so I'm aggregating how this works today and our plans for next steps.
Thanks for the quick response! I did not mention this in the issue description, but the ability to pickle the compiled functions would also be great (we use dill, which can also pickle functions). Considering it's already possible to fetch the original function/method of a compiled function/model, the simplest solution that would work for us is exposing the params passed to torch.compile (e.g., as an attribute of the compile context). Then, we could define a simple reduction function to make the pickling possible.
(If I'm not mistaken, only some params can be fetched/inferred currently, e.g. disable.)
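To illustrate the idea, here is a rough sketch (not something that works today: the `_compile_kwargs` attribute below is hypothetical, only `_orig_mod` exists) of what such a reduction function could look like:

```python
import copyreg

import torch
import torch._dynamo

# Sketch only: assumes the kwargs passed to torch.compile were exposed on the
# wrapper as `_compile_kwargs` (a hypothetical attribute, not in PyTorch today).
def _reduce_compiled_module(compiled_model):
    kwargs = getattr(compiled_model, "_compile_kwargs", {})  # hypothetical attribute
    return _rebuild_compiled_module, (compiled_model._orig_mod, kwargs)

def _rebuild_compiled_module(orig_mod, kwargs):
    # Rebuild the compiled wrapper from the original module on unpickling.
    return torch.compile(orig_mod, **kwargs)

copyreg.pickle(torch._dynamo.eval_frame.OptimizedModule, _reduce_compiled_module)
```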
@mariosasko just wanna make sure I understand: as a first step it sounds like you're mostly interested in knowing exactly which args were used when compiling a model, for reproducibility. If so, I was also planning on just putting that in the nn module state dict.
I would really love to just be able to dill or pickle an entire optimized module, but there are way too many setattrs to make that possible easily. I'll still dig through it to see what's possible.
Any updates on this, or has anyone found a workaround?
cc: @msaroufim
So the simplest workaround is to save the state dict and not the model, which we mentioned back when 2.0 was released: https://pytorch.org/get-started/pytorch-2.0/#serialization
I tried to get saving the model to work directly here: https://github.com/pytorch/pytorch/pull/101651. It did work, you could effectively save compiled models directly, but the problem was that my changes introduced some extra graph breaks across the board, which have a performance impact. I couldn't figure it out and I don't have the bandwidth to inspect it further, but if someone would like to revisit it, I'd be happy to review and merge.
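For reference, a minimal sketch of that state-dict workaround (the module and file names here are placeholders):

```python
import torch
import torch.nn as nn

model = torch.compile(nn.Linear(4, 4))
# ... train ...

# Save only the parameters; keys will carry the "_orig_mod." prefix.
torch.save(model.state_dict(), "checkpoint.pth")

# Later: rebuild and compile the module first, then load the checkpoint into
# the compiled wrapper so the prefixed keys match. Compilation itself still
# happens again on the first forward pass.
restored = torch.compile(nn.Linear(4, 4))
restored.load_state_dict(torch.load("checkpoint.pth"))
```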
My workaround is to "repair" checkpoints that contain the undesired "_orig_mod." prefix.
Save the following script:
```python
import sys

import torch


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def repair_checkpoint(path):
    ckpt = torch.load(path)
    in_state_dict = ckpt["model_state_dict"]
    pairings = [
        (src_key, remove_prefix(src_key, "_orig_mod."))
        for src_key in in_state_dict.keys()
    ]
    if all(src_key == dest_key for src_key, dest_key in pairings):
        return  # Do not write checkpoint if no need to repair!
    out_state_dict = {}
    for src_key, dest_key in pairings:
        print(f"{src_key} ==> {dest_key}")
        out_state_dict[dest_key] = in_state_dict[src_key]
    ckpt["model_state_dict"] = out_state_dict
    torch.save(ckpt, path)


if __name__ == "__main__":
    paths = sys.argv[1:]
    for path in paths:
        print(path)
        repair_checkpoint(path)
        print("========")
```
Then:
```
python checkpoint_unwrap_orig_model.py **/*.pth
```
NOTE: In my checkpoints, the state_dict is actually inside ckpt["model_state_dict"]. If yours is in a different place, adjust that as necessary, e.g. just ckpt if your state_dict is exactly the root of the checkpoint.
@msaroufim I tried your workaround (torch.compile my model, then save the state_dict, then load a new non-compiled version of my model, finally insert the saved state_dict) and the model is slow (non-compiled). Am I doing something wrong? Or are you saying that I have to recompile regardless of whether I load a compiled state_dict or not?
If possible (i.e. because you can change the code that is generating the checkpoints), you can save model._orig_mod.state_dict() instead of model.state_dict() for compiled models. This will avoid the _orig_mod. prefix everywhere, and therefore not require any fixing on load. You just need to load_state_dict() and torch.compile() after loading.
Generic code that doesn't know whether a model is compiled or not could do something like:
```python
getattr(model, '_orig_mod', model).state_dict()
```
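Putting those pieces together, a small sketch of the save/load flow (module and file names are placeholders):

```python
import torch
import torch.nn as nn

model = torch.compile(nn.Linear(4, 4))

# Save the underlying module's parameters, so keys carry no "_orig_mod." prefix;
# getattr() keeps this working for plain (non-compiled) modules too.
torch.save(getattr(model, "_orig_mod", model).state_dict(), "weights.pth")

# Load into a fresh, uncompiled module, then compile again afterwards.
restored = nn.Linear(4, 4)
restored.load_state_dict(torch.load("weights.pth"))
restored = torch.compile(restored)
```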
This is a duplicate of https://github.com/pytorch/pytorch/issues/93470
I usually override the state_dict() and load_state_dict() methods, using a structure like the following to avoid most of these problems, for example:
```python
from collections import OrderedDict

import torch.nn as nn


class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.kernel: nn.Module = ...  # your actual model goes here

    # Override state_dict(): ignore any prefix the caller passes in
    # (e.g. "_orig_mod." when the wrapper is compiled) and export the
    # kernel's parameters under their plain names.
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        prefix = ""  # remove prefix
        self.kernel.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        return destination

    # Override load_state_dict(): feed the plain-named keys straight back
    # into the kernel.
    def load_state_dict(self, state_dict, strict=True):
        return self.kernel.load_state_dict(state_dict, strict=strict)
```
This structure hides the '_orig_mod.' prefix. However, the code above is just a schematic example of the approach and is not runnable as-is (the kernel is a placeholder); adjust the structure to fit your own code.
@msaroufim any updates?
Nope, unassigning myself for now since I haven't had time to keep fixing issues here.
Isn't torch.export.export kind of similar to what is requested here? https://pytorch.org/docs/stable/export.html
@fxmarty torch.export.export only compiles a single path from the program control flow, as far as I understand.
It would be great to be able to save a model that has run torch.compile so that we do not need to re-compile each time we launch a program! +1 for this
@rfeinman Could you clarify what you mean by "torch.export.export only compiles a single path from the program control flow"? Export should be able to handle control flow if it is rewritten using torch.cond.
@angelayi my understanding is based off of the explanation here: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing
Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.
One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).
TorchScript tracing f1 results in silently incorrect results, since only the actual control flow path is traced.
I think the ability of torch.compile to handle arbitrary Python code with minimal changes is a very nice feature, and it would be great if this feature could transfer to serialization (i.e., if we don't have to swap in torch.cond, etc.).
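As a small illustration of that point (the function body here is arbitrary), the data-dependent branch below works under torch.compile without any rewrite, whereas export would need it expressed with torch.cond:

```python
import torch

def f(x):
    # Data-dependent control flow: torch.compile handles this (with a graph
    # break), while torch.export requires rewriting it using torch.cond.
    if x.sum() < 0:
        return -x
    return x

compiled_f = torch.compile(f)
print(compiled_f(torch.randn(4)))
```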
@rfeinman that makes sense! torch.export wants to get a full graph representation of the code so it requires these code rewrites, instead of defaulting to the python code, which is what torch.compile does.
Maybe a single GraphModule is able to represent torch.cond then? I was not aware of that.
edit: yep:
```python
import torch


def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()


def false_fn(x: torch.Tensor):
    return x.sin()


class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))


dyn_shape_mod = DynamicShapeCondPredicate()
res = dyn_shape_mod(torch.randn(10, 10))

##
from torch.export import export

example_args = (torch.randn(10, 10),)
exported_program = export(
    DynamicShapeCondPredicate(), args=example_args
)
print(exported_program)
print(exported_program.graph)
```
```
graph():
    %l_x_ : [num_users=1] = placeholder[target=l_x_]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (True, %true_graph_0, %false_graph_0, [%l_x_]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%conditional, 0), kwargs = {})
    return (getitem,)
```
now when it comes to which hardware/compiler is able to consume that...
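On the serialization side, and assuming a recent PyTorch release where torch.export.save/load and ExportedProgram.module() are available, an ExportedProgram like the one above can at least be saved and reloaded without re-tracing:

```python
# Sketch assuming a recent torch.export API; the file name is a placeholder.
torch.export.save(exported_program, "cond_model.pt2")
reloaded = torch.export.load("cond_model.pt2")
print(reloaded.module()(torch.randn(10, 10)).shape)
```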
🐛 Describe the bug
Serializing a compiled model with pickle fails with `Can't pickle local object 'convert_frame.<locals>._convert_frame'`, and with `cannot pickle 'ConfigModuleInstance' object` when using dill.
A Colab with an example: https://colab.research.google.com/drive/1v6jUUq86ql1Era4X47cIDj7bzrrz2RZe?usp=sharing
In Hugging Face Datasets, this error stops us from generating (deterministic) hashes for transforms (functions) that reference a compiled model, meaning such transforms cannot be cached and must be re-computed each time when transforming a dataset.
(The "export" API for the compiled models would also work for us.)
Error logs
No response
Minified repro
No response
Versions
Colab env with torch 2.0.1 installed
```
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.10.11 (main, Apr 5 2023, 14:15:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.147+-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 2
On-line CPU(s) list: 0,1
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping: 0
CPU MHz: 2200.196
BogoMIPS: 4400.39
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB
L1i cache: 32 KiB
L2 cache: 256 KiB
L3 cache: 55 MiB
NUMA node0 CPU(s): 0,1
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect
```
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @soumith @wconstab @ngimel