Oneflow-Inc / oneflow

OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.
http://www.oneflow.org
Apache License 2.0

Load model's parameters like torchvision.models #3721

Open ximu1211 opened 3 years ago

ximu1211 commented 3 years ago

If I want to access data such as weights or biases, as with torchvision.models, what should I do?

ximu1211 commented 3 years ago

So far, I have not found a working example of using a pretrained model to initialize a new layer.

daquexian commented 3 years ago

> If I want to access data such as weights or biases, as with torchvision.models, what should I do?

Do you mean reading weights from a pretrained model? Currently, weights can only be saved to or loaded from disk. There is a work-in-progress PR (https://github.com/Oneflow-Inc/oneflow/pull/3540) that makes model loading and saving much more flexible. After it is merged (maybe next week), you can read weights from pretrained models and use them to initialize a new layer.

ximu1211 commented 3 years ago

> > If I want to access data such as weights or biases, as with torchvision.models, what should I do?
>
> Do you mean reading weights from a pretrained model? Currently, weights can only be saved to or loaded from disk. There is a work-in-progress PR (#3540) that makes model loading and saving much more flexible. After it is merged (maybe next week), you can read weights from pretrained models and use them to initialize a new layer.

Actually, I downloaded a pretrained model from https://github.com/Oneflow-Inc/OneFlow-Benchmark/tree/master/Classification/cnns. It's a OneFlow ResNet model. I can use checkpoint.load, but I don't know how to use its weights to initialize my own layers.

daquexian commented 3 years ago

> Actually, I downloaded a pretrained model from https://github.com/Oneflow-Inc/OneFlow-Benchmark/tree/master/Classification/cnns. It's a OneFlow ResNet model. I can use checkpoint.load, but I don't know how to use its weights to initialize my own layers.

For now, you have to rename the pretrained weight file.

For example, if you download the pretrained weights to "weight_dir", and want to use "conv1_weight" of the oneflow-resnet model to init "my_conv1_weight" of your own model, you need to rename the file "weight_dir/conv1_weight" to "weight_dir/my_conv1_weight" and load the pretrained weights by checkpoint.load(weight_dir).
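The renaming step can be scripted. This is a minimal sketch using only the Python standard library. The helper name `rename_pretrained_weights` and the `name_map` argument are made up for illustration and are not part of OneFlow. It assumes each weight is stored as one entry (a file or a directory) directly under the weight directory.

```python
import tempfile
from pathlib import Path


def rename_pretrained_weights(weight_dir, name_map):
    """Rename entries under weight_dir from pretrained names to your own names.

    Hypothetical helper, not a OneFlow API. name_map maps an old variable
    name (e.g. "conv1_weight") to the new one (e.g. "my_conv1_weight").
    """
    root = Path(weight_dir)
    for old_name, new_name in name_map.items():
        src = root / old_name
        dst = root / new_name
        if not src.exists():
            raise FileNotFoundError(f"no pretrained weight named {old_name!r} in {root}")
        if dst.exists():
            raise FileExistsError(f"{dst} already exists; refusing to overwrite")
        # Path.rename moves a plain file or a whole directory alike.
        src.rename(dst)


if __name__ == "__main__":
    # Demo on a throwaway directory with a fake weight file.
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "conv1_weight").write_bytes(b"\x00" * 4)
        rename_pretrained_weights(tmp, {"conv1_weight": "my_conv1_weight"})
        print(sorted(p.name for p in Path(tmp).iterdir()))
```

After renaming, load the directory with checkpoint.load(weight_dir) as described above. Because the rename is done in place, it may be worth working on a copy of the downloaded directory.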

Since this is cumbersome, I recommend waiting for #3540 to be merged.