anonymous47823493 / EagleEye

(ECCV'2020 Oral)EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning
304 stars 68 forks

search candidate problems #37

Closed wnov closed 3 years ago

wnov commented 3 years ago

Excellent paper! I am trying to reproduce your results, but the search does not seem to find good enough candidates. I have two questions:

I moved the while loop inside the main function to speed things up. The modified code is as follows:

```python
# Helpers such as ModelWrapper, custom_get_dataloaders, get_pruning_strategy,
# random_compression_scheduler, thinning, and model_summary come from the
# repo's search script.
def main(opt):
    ##################### Basic Settings ####################
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_ids)[1:-1]
    if torch.cuda.is_available():
        device = "cuda"
        torch.backends.cudnn.benchmark = True
    else:
        device = "cpu"

    ##################### Get Dataloader ####################
    dataloader_train, dataloader_val = custom_get_dataloaders(opt)

    # Cache a fixed subset of training batches once, so every candidate is
    # evaluated on the same data without re-iterating the dataloader.
    train_data = []
    for index, sample in enumerate(tqdm(dataloader_train, leave=False)):
        train_data.append(sample)
        if index > 100:
            break

    # dummy_input is a sample input from the dataloaders
    if hasattr(dataloader_val, "dataset"):
        dummy_input = dataloader_val.dataset.__getitem__(0)
        dummy_input = dummy_input[0]
        dummy_input = dummy_input.unsqueeze(0)
    else:
        # for the ImageNet DALI loader
        dummy_input = torch.rand(1, 3, 224, 224)

    while True:
        #####################  Create Baseline Model  ####################
        net = ModelWrapper(opt)
        net.load_checkpoint(opt.checkpoint)
        flops_before, params_before = model_summary(net.get_compress_part(), dummy_input)

        #####################  Pruning Strategy Generation ###############
        compression_scheduler = distiller.file_config(
            net.get_compress_part(), net.optimizer, opt.compress_schedule_path
        )
        num_layer = len(compression_scheduler.policies[1])

        channel_config = get_pruning_strategy(opt, num_layer)  # pruning strategy

        compression_scheduler = random_compression_scheduler(
            compression_scheduler, channel_config
        )

        ###### Adaptive-BN-based Candidate Evaluation of Pruning Strategy ###
        try:
            thinning(net, compression_scheduler, input_tensor=dummy_input)
        except Exception:
            print("[WARNING] This pruning strategy is invalid for the distiller thinning module; skipping it.")
            continue

        flops_after, params_after = model_summary(net.get_compress_part(), dummy_input)
        ratio = flops_after / flops_before
        print("FLOPs ratio:", ratio)
        if ratio < opt.flops_target - 0.01 or ratio > opt.flops_target + 0.01:
            # illegal pruning strategy: FLOPs ratio misses the target window
            continue

        net = net.to(device)
        net.parallel(opt.gpu_ids)
        net.get_compress_part().train()
        with torch.no_grad():
            for index, sample in enumerate(tqdm(train_data, leave=True)):
                _ = net.get_loss(sample)

        strategy_score = net.get_eval_scores(dataloader_val)["accuracy"]

        #################### Save Pruning Strategy and Score #########
        log_file = open(opt.output_file, "a+")
        log_file.write("{} {} ".format(strategy_score, ratio))
        for item in channel_config:
            log_file.write("{} ".format(str(item)))
        log_file.write("\n")
        log_file.close()
        print("Eval Score:{}".format(strategy_score))

        if strategy_score >= 0.141:
            return
```
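For reference, the adaptive-BN candidate evaluation performed inside the loop (forwarding the cached batches in train mode with gradients disabled so BatchNorm statistics are recalibrated for the pruned network) can be sketched in isolation. This is an illustrative helper, not the repo's API; `adaptive_bn_eval` and its arguments are made-up names:

```python
import torch
import torch.nn as nn

def adaptive_bn_eval(model, cached_batches, device="cpu"):
    """Recalibrate BatchNorm statistics of a pruned candidate on a small,
    fixed set of training batches (illustrative sketch, not the repo's API)."""
    # Reset running stats so they are re-estimated for the pruned network.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
    model.train()            # train mode: BN updates running mean/var
    with torch.no_grad():    # only forward passes are needed, no gradients
        for x in cached_batches:
            model(x.to(device))
    model.eval()             # freeze the recalibrated stats for evaluation
    return model
```

After this step, validation accuracy of the recalibrated model serves as the candidate's score, exactly as `net.get_eval_scores` does above.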
Bowenwu1 commented 3 years ago

Hi! Thank you very much for your interest in our work.

  1. The repeated computation in our code is mainly there to guarantee that every evaluation of a pruning strategy uses the same data, so that evaluations are directly comparable. As for execution efficiency, due to time constraints we did not optimize much. Your modification keeps the data identical across evaluations while removing the per-iteration dataloader cost, which is a very nice improvement!
  2. The cost of random search can indeed fluctuate considerably due to factors such as the random seed; this is a known drawback of random search. In practice, we recommend parallelizing the search across multiple GPUs to improve efficiency, and also suggest trying heuristic search in real-world scenarios.