mit-han-lab / once-for-all

[ICLR 2020] Once for All: Train One Network and Specialize it for Efficient Deployment
https://ofa.mit.edu/
MIT License

How to export one of the subnets? #51

Open gianderiu opened 3 years ago

gianderiu commented 3 years ago

I'd like to export some of the subnets, either as an ONNX file or as a .pth file plus the network class. How can I do this?

Darshcg commented 3 years ago

Yes, I have the same question.

Bixiii commented 3 years ago

You can use the torch.onnx package.

If your trained OFA network is ofa_network, you can sample a random subnet with ofa_network.sample_active_subnet(), then cut that subnet out with subnet = ofa_network.get_active_subnet(). Don't forget to reset the batch norm statistics with reset_running_statistics(net=subnet). Then you can export it like any other model:

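import torch

subnet.eval()  # inference mode: BatchNorm uses the recalibrated running statistics
dummy_input = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image, used to trace the graph
torch.onnx.export(
    subnet,
    dummy_input,
    'model_name.onnx',
    export_params=True,  # store the trained weights inside the ONNX file
)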
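To make "sample a random subnet" concrete, here is a minimal pure-Python sketch of what such a sampled architecture description can look like. It assumes the search space of the released OFA-MobileNetV3 supernet (kernel sizes 3/5/7, expand ratios 3/4/6, depths 2/3/4, five elastic stages of up to four blocks); those constants and the helper name sample_arch_config are assumptions for illustration, not code taken from the repo.

```python
import random

# Assumed search space of the OFA-MobileNetV3 supernet (adjust for your model).
KS_LIST = [3, 5, 7]       # candidate kernel sizes
EXPAND_LIST = [3, 4, 6]   # candidate expand ratios
DEPTH_LIST = [2, 3, 4]    # candidate depths per stage
NUM_STAGES = 5            # assumed number of elastic stages
MAX_DEPTH = max(DEPTH_LIST)


def sample_arch_config(rng=random):
    """Hypothetical helper: draw one random subnet in ks/e/d form.

    ks and e hold one entry per elastic block slot (NUM_STAGES * MAX_DEPTH);
    d holds one depth per stage, so slots beyond a stage's depth go unused.
    """
    n_slots = NUM_STAGES * MAX_DEPTH
    return {
        "ks": [rng.choice(KS_LIST) for _ in range(n_slots)],
        "e": [rng.choice(EXPAND_LIST) for _ in range(n_slots)],
        "d": [rng.choice(DEPTH_LIST) for _ in range(NUM_STAGES)],
    }


config = sample_arch_config(random.Random(0))  # seeded, so the draw is repeatable
print(config["d"])
```

If you keep such a dict instead of sampling each time, you should be able to pass it as ofa_network.set_active_subnet(ks=config["ks"], e=config["e"], d=config["d"]) before calling get_active_subnet(), assuming your supernet exposes set_active_subnet with these keyword arguments as OFAMobileNetV3 does. That way the exported subnet is the same on every run.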
detectRecog commented 3 years ago

How do I extract a subnet that matches one of the preset configs, such as "pixel2_lat@25ms_top1@71.5_finetune@25"? After loading the full OFA network trained on a custom dataset, I can't figure out how to get the subnet that corresponds to a preset config.
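One thing worth noting: the preset name only describes the deployment target and the result (device, latency budget, top-1 accuracy, and what appears to be the number of fine-tuning epochs, which is an assumption here), not the layer-by-layer architecture. The minimal sketch below splits such a name into its key@value fields; parse_net_id is a hypothetical helper, and it shows that the kernel sizes, expand ratios and depths must come from the preset's net.config rather than from the name.

```python
def parse_net_id(net_id):
    """Hypothetical helper: split a preset name into its fields.

    Tokens of the form key@value become dictionary entries; any token
    without '@' is treated as part of the target-device name.
    """
    fields = {}
    device_parts = []
    for token in net_id.split("_"):
        if "@" in token:
            key, value = token.split("@", 1)
            fields[key] = value
        else:
            device_parts.append(token)
    fields["device"] = "_".join(device_parts) or None  # None if no device token
    return fields


print(parse_net_id("pixel2_lat@25ms_top1@71.5_finetune@25"))
# {'lat': '25ms', 'top1': '71.5', 'finetune': '25', 'device': 'pixel2'}
```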

Bixiii commented 3 years ago

I have not tried this myself, but I would try something like this.

Either:

Write a script that builds, from the net.config file, an architecture configuration in the form this function expects. Then do everything as described previously.
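A minimal sketch of such a conversion script, assuming the target is a ks/e/d style call like the supernet's set_active_subnet(ks=..., e=..., d=...). It further assumes that you have already pulled each elastic block's conv settings (kernel_size, expand_ratio, out_channels) out of net.config with the fixed first block dropped, that a new stage starts whenever out_channels changes, and that the supernet has five stages of at most four blocks. The field names, constants and the helper blocks_to_arch_config are all assumptions, so check them against your actual net.config.

```python
MAX_DEPTH = 4    # assumed: deepest allowed stage in the supernet
NUM_STAGES = 5   # assumed: number of elastic stages
FILLER = 3       # assumed: a valid ks/e value for unused block slots


def blocks_to_arch_config(conv_cfgs):
    """Hypothetical helper: turn per-block conv settings into ks/e/d lists.

    conv_cfgs: one dict per elastic block, each with 'kernel_size',
    'expand_ratio' and 'out_channels'. Consecutive blocks with the same
    out_channels are grouped into one stage.
    """
    stages = []
    prev_out = None
    for cfg in conv_cfgs:
        if cfg["out_channels"] != prev_out:  # channel change => new stage
            stages.append([])
            prev_out = cfg["out_channels"]
        stages[-1].append(cfg)
    if len(stages) != NUM_STAGES:
        raise ValueError(f"expected {NUM_STAGES} stages, found {len(stages)}")

    ks, e, d = [], [], []
    for stage in stages:
        if len(stage) > MAX_DEPTH:
            raise ValueError("stage is deeper than the supernet allows")
        d.append(len(stage))
        pad = MAX_DEPTH - len(stage)  # inactive slots are padded with FILLER
        ks += [b["kernel_size"] for b in stage] + [FILLER] * pad
        e += [b["expand_ratio"] for b in stage] + [FILLER] * pad
    return {"ks": ks, "e": e, "d": d}


# Hypothetical example: five stages with depths 2, 3, 4, 2, 3.
widths_and_depths = [(24, 2), (40, 3), (80, 4), (112, 2), (160, 3)]
example = [
    {"kernel_size": 5, "expand_ratio": 4, "out_channels": c}
    for c, depth in widths_and_depths
    for _ in range(depth)
]
arch = blocks_to_arch_config(example)
print(arch["d"])  # [2, 3, 4, 2, 3]
```

Padding the unused slots is a design choice for keeping the ks and e lists at a fixed length; under the stated assumptions, those slots belong to blocks beyond each stage's active depth and should not affect the extracted subnet.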

Or:

Create a model with the desired architecture, as is done here. In your example, use "pixel2_lat@25ms_top1@71.5_finetune@25" as the net_id.

Then a little more work is needed: you have to load the weights from your OFA network into the subnetwork. For that, you need to write a function similar to this one, but load only the values from your OFA network that the subnet architecture actually uses. You also have to call reset_running_statistics() for the subnet.

Then you can export the subnetwork as described previously.