daniel-sudz opened 1 year ago
So I was able to get scriptable model export running, but unfortunately not with Metal support on iOS because of this issue: https://github.com/pytorch/pytorch/issues/69609 (it looks like you need to compile PyTorch from source).
CPU-based export worked fine for me with the following:
"""
Converts the ACE model for mobile usage
"""
def save_model_for_mobile(ace_encoder_pretrained: Path, trained_weights: Path):
encoder_state_dict = torch.load(ace_encoder_pretrained, map_location="cpu")
head_network_dict = torch.load(trained_weights, map_location="cpu")
device = torch.device("cuda")
network = Regressor.create_from_split_state_dict(encoder_state_dict, head_network_dict)
network = network.to(device)
network.eval()
scripted_module = torch.jit.script(network)
# it looks it's not trivial to optimize for mobile gpu right because this issue:
optimized_model = optimize_for_mobile(scripted_module, backend='CPU')
optimized_model.save(trained_weights.parent / "mobile.model.pt")
optimized_model._save_for_lite_interpreter((trained_weights.parent / "mobile.model.ptl").as_posix())
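For anyone wanting to sanity-check the export path without the ACE weights, here is a minimal sketch (not from this PR) that runs the same script/optimize/save pipeline on a tiny stand-in module, reloads the `.ptl` file with the lite interpreter, and checks that the outputs match the eager model. `TinyNet` and the file name `tiny.ptl` are placeholders for illustration:

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import _load_for_lite_interpreter


class TinyNet(torch.nn.Module):
    """Stand-in for the ACE Regressor: one conv + ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


net = TinyNet().eval()

# same pipeline as save_model_for_mobile: script -> optimize (CPU) -> lite export
scripted = torch.jit.script(net)
optimized = optimize_for_mobile(scripted, backend="CPU")
optimized._save_for_lite_interpreter("tiny.ptl")

# reload with the lite interpreter and compare against the eager model
reloaded = _load_for_lite_interpreter("tiny.ptl")
x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    eager_out = net(x)
    lite_out = reloaded(x)
assert torch.allclose(eager_out, lite_out, atol=1e-5)
```

If the final `allclose` check passes, the lite-interpreter artifact should behave the same as the eager model on device (modulo backend-specific numerics).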
Closes #20
Here is a picture of the dumped parameter dictionary to show the naming change:
Before (origin/main)
After (daniel-sudz/scriptable_model)