PaddlePaddle / PaddleSlim

PaddleSlim is an open-source library for deep model compression and architecture search.
https://paddleslim.readthedocs.io/zh_CN/latest/
Apache License 2.0
1.56k stars 345 forks source link

使用 PaddleSlim 对模型进行动态图剪枝后，如何保存模型？ #1861

Open · MiXianif opened this issue 8 months ago

MiXianif commented 8 months ago

我使用 paddle.save(net.state_dict(), path) 保存模型，发现剪枝后保存的模型文件反而比剪枝前的还要大；但通过 paddle.summary(net, (1, 3, 32, 32)) 查看，模型确实变小了。

环境：paddlepaddle==2.6.0，paddleslim==2.6.0

# Below is my code:
from __future__ import print_function  # no-op on Python 3; kept from the original paste

import paddle
from paddle.vision.models import mobilenet_v1

# Build MobileNetV1 without pretrained weights and print its summary for a
# single 3x32x32 input (the size of a CIFAR-10 image).
net = mobilenet_v1(pretrained=False)
paddle.summary(net, (1, 3, 32, 32))

import paddle.vision.transforms as T

# Preprocessing: T.Transpose() presumably reorders HWC images to CHW to match
# the model's 3x32x32 input. Normalize with mean=std=127.5 maps pixel values in
# [0, 255] to [-1, 1] (assuming uint8 pixels from the cv2 backend).
transform = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5]),
])

train_dataset = paddle.vision.datasets.Cifar10(
    mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
    mode="test", backend="cv2", transform=transform)

print(f'train samples count: {len(train_dataset)}')
print(f'val samples count: {len(val_dataset)}')

# Peek at the first training sample only, as a sanity check on shape/label.
for data in train_dataset:
    print(f'image shape: {data[0].shape}; label: {data[1]}')
    break

from paddle.static import InputSpec as Input

# The same optimizer object is handed to the pruner later (opt=optimizer),
# presumably so its accumulators are pruned together with the weights.
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=net.parameters())

inputs = [Input([None, 3, 32, 32], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]

model = paddle.Model(net, inputs, labels)

model.prepare(
    optimizer,
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 5)),
)

# Train the unpruned baseline, record its accuracy, and save its weights so
# the checkpoint size can be compared with the pruned ones.
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset, batch_size=128, log_freq=10, verbose=0)
print(f"baseline: {result}")
paddle.save(net.state_dict(), "./runs/FPGMFilterPruner/model.pdparams")

from paddleslim.dygraph import L1NormFilterPruner, FPGMFilterPruner

# Dygraph filter pruner over `net`, traced with a 1x3x32x32 input. Passing the
# optimizer presumably lets the pruner keep optimizer state consistent with
# the pruned parameter shapes.
pruner = FPGMFilterPruner(net, [1, 3, 32, 32], opt=optimizer)


def eval_fn():
    """Evaluate `model` on `val_dataset` and return its 'acc_top1' metric.

    Used as the scoring callback for the pruner's sensitivity analysis.
    """
    result = model.evaluate(val_dataset, batch_size=128, verbose=0)
    return result['acc_top1']


# Sensitivity analysis; results are written to ./sen.pickle.
pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")

from paddleslim.analysis import dygraph_flops

flops = dygraph_flops(net, [1, 3, 32, 32])
print(f"FLOPs before pruning: {flops}")

# Prune according to the sensitivity results; 0.4 appears to be the target
# FLOPs reduction (plan.pruned_flops is reported as a percentage below).
# NOTE(review): confirm "conv2d_26.w" matches a real parameter name (e.g. via
# net.named_parameters(); dygraph weights are often named like "conv2d_26.w_0").
# If it does not match, that layer may not actually be skipped.
plan = pruner.sensitive_prune(0.4, skip_vars=["conv2d_26.w"])

# NOTE(review): this is likely why the pruned checkpoint is larger than the
# unpruned one. The PaddleSlim dygraph pruner appears to register persistable
# "*_backup" buffers that hold the original, unpruned weights (and optimizer
# state) so that pruning can be restored. Buffers like these are included in
# state_dict(). To check, inspect
# [k for k in net.state_dict() if k.endswith("_backup")] before trusting the
# file size. TODO confirm against the installed PaddleSlim version.
paddle.save(net.state_dict(), "./runs/FPGMFilterPruner/pruning.pdparams")

flops = dygraph_flops(net, [1, 3, 32, 32])
print(f"FLOPs after pruning: {flops}")
print(f"Pruned FLOPs: {round(plan.pruned_flops * 100, 2)}%")

# Accuracy right after pruning, before any recovery training.
result = model.evaluate(val_dataset, batch_size=128, log_freq=10, verbose=0)
print(f"before fine-tuning: {result}")

# Fine-tune the pruned network to recover accuracy, then evaluate again.
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset, batch_size=128, log_freq=10, verbose=0)
print(f"after fine-tuning: {result}")

# NOTE(review): state_dict() may also contain extra persistable buffers that
# the pruner registered (e.g. keys ending in "_backup" that hold pre-pruning
# values). That would make this file larger even though the parameters are
# smaller. Inspect net.state_dict().keys() to confirm before comparing sizes.
paddle.save(net.state_dict(), "./runs/FPGMFilterPruner/pruned.pdparams")

# The summary reflects the actual (pruned) parameter shapes.
paddle.summary(net, (1, 3, 32, 32))