Closed wnov closed 3 years ago
非常棒的论文,我想复现下您的论文,但是似乎搜索不到足够好的candidate,有两个问题
search.py 在 main 函数外层使用 while 循环反复调用,导致数据集加载等部分被重复计算,这样设计是出于什么考虑呢?
除此之外,我执行 search.py 搜索的过程中,找不到评分高于 0.05 的 candidate,执行时间超过了 1 GPU day。个人感觉随机搜索的策略似乎并不能保证在确定时间内找到足够好的 candidate。
# NOTE (issue author): the `while` loop was moved inside main() so the dataloader
# is built only once, speeding up the search. Modified code follows.
def main(opt, score_threshold=0.141):
    """Randomly search pruning strategies until one scores well enough.

    Builds the dataloaders once, then repeatedly samples a random channel
    pruning configuration, thins the model, and evaluates it with the
    Adaptive-BN-based scheme. Every evaluated (score, FLOPs-ratio, config)
    triple is appended to ``opt.output_file``.

    Args:
        opt: Parsed experiment options (checkpoint path, gpu_ids,
            flops_target, compress_schedule_path, output_file, ...).
        score_threshold: Accuracy at which the search stops. Defaults to
            0.141 to preserve the original behavior.

    Returns:
        None. The function returns once a candidate reaches the threshold.
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_ids)[1:-1]
    if torch.cuda.is_available():
        device = "cuda"
        torch.backends.cudnn.benchmark = True
    else:
        device = "cpu"

    ##################### Get Dataloader ####################
    # Built once, outside the search loop, to avoid recomputing per candidate.
    dataloader_train, dataloader_val = custom_get_dataloaders(opt)

    # Cache ~100 training batches for the Adaptive-BN recalibration pass.
    train_data = []
    for index, sample in enumerate(tqdm(dataloader_train, leave=False)):
        train_data.append(sample)
        if index > 100:
            break

    # dummy_input is a sample input used for FLOPs counting and thinning.
    if hasattr(dataloader_val, "dataset"):
        dummy_input = dataloader_val.dataset.__getitem__(0)
        dummy_input = dummy_input[0]
        dummy_input = dummy_input.unsqueeze(0)
    else:
        # for imagenet dali loader — assumes 3x224x224 input; TODO confirm
        dummy_input = torch.rand(1, 3, 224, 224)

    while True:
        ##################### Create Baseline Model ####################
        net = ModelWrapper(opt)
        net.load_checkpoint(opt.checkpoint)
        flops_before, params_before = model_summary(
            net.get_compress_part(), dummy_input
        )

        ##################### Pruning Strategy Generation ###############
        compression_scheduler = distiller.file_config(
            net.get_compress_part(), net.optimizer, opt.compress_schedule_path
        )
        num_layer = len(compression_scheduler.policies[1])
        channel_config = get_pruning_strategy(opt, num_layer)  # pruning strategy
        compression_scheduler = random_compression_scheduler(
            compression_scheduler, channel_config
        )

        ###### Adaptive-BN-based Candidate Evaluation of Pruning Strategy ###
        try:
            thinning(net, compression_scheduler, input_tensor=dummy_input)
        except Exception:
            # Bare `except:` would also swallow KeyboardInterrupt/SystemExit,
            # making the search loop impossible to interrupt.
            print('[WARNING] This pruning strategy is invalid for distiller thinning module, pass it.')
            continue

        flops_after, params_after = model_summary(
            net.get_compress_part(), dummy_input
        )
        ratio = flops_after / flops_before
        print("FLOPs ratio:", ratio)
        if ratio < opt.flops_target - 0.01 or ratio > opt.flops_target + 0.01:
            # Illegal pruning strategy: FLOPs outside the +/-0.01 target band.
            continue

        net = net.to(device)
        net.parallel(opt.gpu_ids)
        # Adaptive-BN: forward passes in train() mode (BN stats update) but
        # without gradients, to recalibrate BatchNorm statistics.
        net.get_compress_part().train()
        with torch.no_grad():
            for index, sample in enumerate(tqdm(train_data, leave=True)):
                _ = net.get_loss(sample)

        strategy_score = net.get_eval_scores(dataloader_val)["accuracy"]

        #################### Save Pruning Strategy and Score #########
        # `with` guarantees the log file is closed even if a write fails.
        with open(opt.output_file, "a+") as log_file:
            log_file.write("{} {} ".format(strategy_score, ratio))
            for item in channel_config:
                log_file.write("{} ".format(str(item)))
            log_file.write("\n")

        print("Eval Score:{}".format(strategy_score))
        if strategy_score >= score_threshold:
            return
你好!非常感谢你对我们工作的关注。
非常棒的论文,我想复现下您的论文,但是似乎搜索不到足够好的candidate,有两个问题
search.py对main函数while循环,导致数据集部分重复计算是什么考虑呢?
除此之外,我执行search.py 搜索的过程中,找不到评分高于0.05的candidate,执行时间超过了1GPUday,个人感觉随机搜索的策略似乎并不能保证在确定时间内找到足够好的candidate
我将while循环移到main函数内部以加速,修改的代码如下 `def main(opt):
basic settings