pytorch / multipy

torch::deploy (multipy for non-torch uses) is a system that lets you get around the GIL problem by running multiple Python interpreters in a single C++ process.

package issues with functions under C extensions #44

Open d4l3k opened 2 years ago

d4l3k commented 2 years ago
import torch
from torch.package import PackageExporter, PackageImporter

output_path = "/tmp/model.pt"

def save_load(model):
    # Export the model: keep torch as an external dependency,
    # intern (bundle) everything else into the package.
    with PackageExporter(output_path) as e:
        e.extern("torch.**")
        e.intern("**")
        e.save_pickle("model", "model.pkl", model)

    # Load the pickled model back from the package.
    imp = PackageImporter(output_path)
    loaded = imp.load_pickle("model", "model.pkl")
    print("pass")
    return loaded

model = torch.nn.TransformerEncoderLayer(
    d_model=64,
    nhead=2,
    dim_feedforward=64,
    dropout=1.0,
    batch_first=True,
    activation='gelu',
    norm_first=True,
)
save_load(model)

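The issue is that F.gelu can't be loaded from the package due to an import error. With activation='gelu', TransformerEncoderLayer stores F.gelu on the layer as self.activation, so the pickled model holds a reference to that function: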

ModuleNotFoundError: No module named 'torch._C._nn'; 'torch._C' is not a package
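As a stdlib-only sketch of the mechanism (not torch.package itself): pickle stores a function by reference, as its module path plus its name, and the unpickler re-imports that module on load. The hypothetical Layer class below stands in for an nn.Module holding a function attribute, with math.sqrt standing in for F.gelu; global_refs is a hypothetical helper that lists the references pickle records. For F.gelu the recorded module would presumably be torch._C._nn, which is consistent with the error above.

```python
import math
import pickle
import pickletools


class Layer:
    """Hypothetical stand-in for a module that stores a function as an attribute."""

    def __init__(self, activation):
        self.activation = activation


def global_refs(obj):
    """Return the 'module name' references pickle records for obj.

    Protocol 2 emits the GLOBAL opcode, whose argument pickletools
    decodes as a single 'module name' string.
    """
    data = pickle.dumps(obj, protocol=2, fix_imports=False)
    return [arg for op, arg, _ in pickletools.genops(data) if op.name == "GLOBAL"]


# Both the Layer class and the builtin math.sqrt are stored by reference,
# so unpickling must be able to import "math" and look up "sqrt" in it.
print(global_refs(Layer(math.sqrt)))
```

The output includes the entry 'math sqrt': the function itself is not serialized, only where to find it.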
d4l3k commented 2 years ago

You can work around this by not storing functions from torch.nn.functional as attributes on the module, i.e., avoid self.foo = F.gelu.
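A minimal sketch of that workaround, again using math.sqrt as a stdlib stand-in for F.gelu: keep only the function's name on the object and resolve the function at call time, so the pickle contains no reference to it. In torch terms this would presumably mean storing the string "gelu" and calling F.gelu inside forward, or passing an nn.GELU() module instance as the activation (assuming your torch version's TransformerEncoderLayer accepts a callable there). Layer and global_refs are hypothetical names used only for illustration.

```python
import math
import pickle
import pickletools


class Layer:
    """Hypothetical layer that keeps the activation's name, not the function."""

    def __init__(self, activation_name="sqrt"):
        self.activation_name = activation_name

    def activation(self, x):
        # Look the function up at call time; no function object lives in the
        # instance, so none is pickled by reference.
        return getattr(math, self.activation_name)(x)


def global_refs(obj):
    """Return the 'module name' references pickle (protocol 2) records for obj."""
    data = pickle.dumps(obj, protocol=2, fix_imports=False)
    return [arg for op, arg, _ in pickletools.genops(data) if op.name == "GLOBAL"]


layer = Layer("sqrt")
restored = pickle.loads(pickle.dumps(layer, protocol=2))
print(global_refs(layer))        # only the Layer class itself
print(restored.activation(9.0))  # 3.0
```

The trade-off is that the lookup happens on every call, but the serialized object depends only on its own class and plain data.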