I encountered several problems during the reproduction process (data_synthesis_with_svd_with_db_with_all_statistic.py):
Lack of module name
# started from line 54
for i, (_model_teacher) in enumerate(model_teacher):
    print(_model_teacher)
    for name, module in _model_teacher.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            # <- BNFeatureHook is constructed without the name argument it requires
            loss_r_feature_layers[i].append(BNFeatureHook(module, training_momentum=args.training_momentum))
        elif isinstance(module, nn.Conv2d):
            _hook_module = ConvFeatureHook(module, save_path=args.statistic_path,
                                           name=str(_model_teacher.__class__.__name__) + "=" + name,
                                           gpu=gpu, training_momentum=args.training_momentum, drop_rate=args.drop_rate)
            _hook_module.set_hook(pre=True)
            load_tag = load_tag & _hook_module.load_tag
            loss_r_feature_layers[i].append(_hook_module)
print(load_tag)
In the initialization of BNFeatureHook, a name argument is required, but it is not passed here.
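For reference, the workaround I applied locally was to build the hook with a name, mirroring the ConvFeatureHook call right below it. This is only a sketch of my guess; I am assuming BNFeatureHook accepts a name keyword the same way ConvFeatureHook does, which may not be what you intended:

# my local workaround (assumption: BNFeatureHook takes a name keyword like ConvFeatureHook)
if isinstance(module, nn.BatchNorm2d):
    _hook_module = BNFeatureHook(module,
                                 name=str(_model_teacher.__class__.__name__) + "=" + name,
                                 training_momentum=args.training_momentum)
    loss_r_feature_layers[i].append(_hook_module)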
Lack of targets attribute
In the pre_hook_fn and post_hook_fn of the ConvFeatureHook class, self.targets is needed, but the code in data_synthesis_with_svd_with_db_with_all_statistic.py never sets this attribute. I checked recover.py, and it sets self.targets = targets (via set_label) in the loops below; in data_synthesis_with_svd_with_db_with_all_statistic.py, the corresponding loop does not set this attribute.
# from file recover.py
# started from line 171
for j, _model_teacher in enumerate(model_teacher):
    if not load_tag_dict[j]:
        print(f"conduct backbone {args.aux_teacher[j]} statistics")
        for i, (data, targets) in tqdm(enumerate(train_loader)):
            data = data.cuda(gpu)
            targets = targets.cuda(gpu)
            for _loss_t_feature_layer in loss_r_feature_layers[j]:
                _loss_t_feature_layer.set_label(targets)  # <- set self.targets
            _ = _model_teacher(data)
        for _loss_t_feature_layer in loss_r_feature_layers[j]:
            _loss_t_feature_layer.save()

# started from line 244
for mod in loss_r_feature_layers[id]:
    mod.set_label(targets)  # <- set self.targets
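For reference, my local patch to data_synthesis_with_svd_with_db_with_all_statistic.py simply mirrors the recover.py loops above. I am assuming the hooks only need set_label(targets) called before each forward pass so that self.targets exists, which may not be the intended fix:

# my local patch in data_synthesis_with_svd_with_db_with_all_statistic.py
# (assumption: calling set_label(targets) before the forward pass is all the hooks need)
for _loss_t_feature_layer in loss_r_feature_layers[j]:
    _loss_t_feature_layer.set_label(targets)  # <- set self.targets, as recover.py does
_ = _model_teacher(data)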
Incomplete aux_teacher
As stated in the paper, the CIFAR aux_teachers include {ResNet-18, ConvNet-W128, MobileNet-V2, WRN-16-2, ShuffleNet-V2, ConvNet-D1, ConvNet-D2, ConvNet-W32}. However, neither the code in the squeeze dir nor in the recover dir includes all 8 models above. The scripts are mismatched with the paper, and the code does not prepare all 8 backbones either.
# recover > data_synthesis_with_svd_with_db_with_all_statistic.py, from line 331
# only 5 backbones here as well
aux_teacher = ["ResNet18", "ConvNetW128", "MobileNetV2", "WRN_16_2", "ShuffleNetV2_0_5"]
model_teacher = []
for name in aux_teacher:
    if name == "ConvNetW128":
        model = ti_get_network(name, channel=3, num_classes=10, im_size=(32, 32), dist=False)
    else:
        model = ti_models.model_dict[name](num_classes=10)
    model_teacher.append(model)
    checkpoint = torch.load(
        os.path.join(args.pre_train_path, "CIFAR-10", name, f"squeeze_{name}.pth"),
        map_location="cpu")
    model_teacher[-1].load_state_dict(checkpoint)
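For reference, this is how I tried to extend the list to the 8 backbones from the paper. It is only a guess: I am assuming the missing models are named ConvNetD1, ConvNetD2, and ConvNetW32, that ti_get_network can build them the same way it builds ConvNetW128, and that squeeze_{name}.pth checkpoints exist for them, none of which I could confirm from the repo:

# my guess at the full 8-backbone list from the paper
# (assumptions: ConvNetD1/ConvNetD2/ConvNetW32 are valid ti_get_network names
#  and pretrained squeeze_{name}.pth checkpoints are available for them)
aux_teacher = ["ResNet18", "ConvNetW128", "MobileNetV2", "WRN_16_2", "ShuffleNetV2_0_5",
               "ConvNetD1", "ConvNetD2", "ConvNetW32"]
model_teacher = []
for name in aux_teacher:
    if name.startswith("ConvNet"):
        # ConvNet variants go through ti_get_network, like ConvNetW128 above
        model = ti_get_network(name, channel=3, num_classes=10, im_size=(32, 32), dist=False)
    else:
        model = ti_models.model_dict[name](num_classes=10)
    model_teacher.append(model)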
I have tried to fix all of these bugs and finally got the code to run. However, I'm not sure whether I understood the code and fixed it in the right way, so could you please address these issues?
Thanks for your great work and sharing.