nianticlabs / ace

[CVPR 2023 - Highlight] Accelerated Coordinate Encoding (ACE): Learning to Relocalize in Minutes using RGB and Poses
https://nianticlabs.github.io/ace

Support scriptable model export by converting to nn.ModuleList #22

Open daniel-sudz opened 1 year ago

daniel-sudz commented 1 year ago

Closes #20

Here is a picture of the dumped parameter dictionary to show the naming change:

original-dict

Before (origin/main)

new-dict

After (daniel-sudz/scriptable_model)
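Since the parameter names change, checkpoints saved before this change would presumably need their keys renamed before they load into the new model. Here is a rough sketch of how that remapping could look. Both key patterns are assumptions made for illustration (old keys like `0c1.weight`, new keys like `res_blocks.0.1.weight`), and `remap_key` and `remap_state_dict` are hypothetical helpers that are not part of this PR. Check the real names against the two dumps above before using anything like this.

```python
import re

# ASSUMPTION: old-style keys look like "<block>c<conv>.<param>", e.g. "0c1.weight".
# The real naming should be read off the dumped parameter dictionaries.
OLD_KEY = re.compile(
    r"^(?P<prefix>.*?)(?P<block>\d+)c(?P<conv>\d+)\.(?P<param>weight|bias)$"
)


def remap_key(key: str) -> str:
    """Rename one key from the assumed old pattern to the assumed new one.

    Keys that do not match the old pattern are returned unchanged.
    """
    match = OLD_KEY.match(key)
    if match is None:
        return key
    # ASSUMPTION: the ModuleList attribute is called "res_blocks".
    return "{prefix}res_blocks.{block}.{conv}.{param}".format(**match.groupdict())


def remap_state_dict(state_dict: dict) -> dict:
    """Return a copy of state_dict with every key passed through remap_key."""
    remapped = {remap_key(key): value for key, value in state_dict.items()}
    # Guard against two old keys collapsing onto the same new key.
    if len(remapped) != len(state_dict):
        raise ValueError("key collision while remapping state dict")
    return remapped
```

The result could then be passed to `load_state_dict` in place of the original dictionary, assuming the tensor shapes themselves are unchanged.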

daniel-sudz commented 1 year ago

So I was able to get scriptable model export running, but unfortunately not with Metal support on iOS because of this issue: https://github.com/pytorch/pytorch/issues/69609 (it looks like you need to compile PyTorch from source).

CPU-based export seemed to work fine for me with the following:

""" 
    Converts the ACE model for mobile usage
"""
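from pathlib import Path

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

from ace_network import Regressor  # defined in the ACE repository


def save_model_for_mobile(ace_encoder_pretrained: Path, trained_weights: Path):
    # Load both checkpoints onto the CPU.
    encoder_state_dict = torch.load(ace_encoder_pretrained, map_location="cpu")
    head_network_dict = torch.load(trained_weights, map_location="cpu")

    # Keep the network on the CPU, since the export below targets the CPU backend.
    device = torch.device("cpu")
    network = Regressor.create_from_split_state_dict(encoder_state_dict, head_network_dict)
    network = network.to(device)
    network.eval()

    scripted_module = torch.jit.script(network)
    # Optimizing for the mobile GPU does not look trivial right now because of the PyTorch issue linked above.
    optimized_model = optimize_for_mobile(scripted_module, backend='CPU')
    # Save both the TorchScript archive and the lite-interpreter version next to the trained weights.
    optimized_model.save((trained_weights.parent / "mobile.model.pt").as_posix())
    optimized_model._save_for_lite_interpreter((trained_weights.parent / "mobile.model.ptl").as_posix())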