huanzhang12 / CROWN-IBP

Certified defense to adversarial examples using CROWN and IBP. Also includes GPU implementation of CROWN verification algorithm (in PyTorch).
https://openreview.net/pdf?id=Skxuk1rFwB
BSD 2-Clause "Simplified" License
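The repository description mentions IBP (interval bound propagation). Below is a minimal NumPy sketch of the idea, applied to a toy two-layer network under an L∞ box perturbation. The weights, the input `x0`, and `eps` are made up for illustration. This is not the repository's implementation. Note that an L2 perturbation, as used in the issue below, would need a different first-layer bound.

```python
import numpy as np

def ibp_linear(W, b, lower, upper):
    """Propagate the box [lower, upper] through y = W x + b (center/radius form)."""
    center = (upper + lower) / 2.0
    radius = (upper - lower) / 2.0
    new_center = W @ center + b
    new_radius = np.abs(W) @ radius  # worst case over the box, elementwise
    return new_center - new_radius, new_center + new_radius

def ibp_relu(lower, upper):
    """ReLU is monotone, so it maps interval endpoints directly."""
    return np.maximum(lower, 0.0), np.maximum(upper, 0.0)

# Hypothetical toy network: logits = W2 @ relu(W1 @ x + b1) + b2
W1 = np.array([[1.0, -1.0], [0.5, 2.0]])
b1 = np.array([0.0, -1.0])
W2 = np.array([[1.0, 1.0], [-1.0, 1.0]])
b2 = np.array([0.0, 0.5])

x0 = np.array([1.0, 0.5])
eps = 0.1  # L-infinity radius (assumption for this sketch)
lb, ub = x0 - eps, x0 + eps
lb, ub = ibp_relu(*ibp_linear(W1, b1, lb, ub))
lb, ub = ibp_linear(W2, b2, lb, ub)
print("logit lower bounds:", lb)
print("logit upper bounds:", ub)
```

IBP is cheap (one forward pass per bound) but loose. CROWN-style backward bounds are tighter at higher cost.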

Binary search for lp #4

Closed: miaoxiaodaiblack closed this issue 3 years ago

miaoxiaodaiblack commented 3 years ago

Hello! I would like to use the lower and upper bounds (`lb` and `ub`) computed by the code in `simple_verification.py` to implement the binary search method from your CROWN paper. I modified the code as follows:

```python
def test(input, model):
    eps = 0
    gap_gx = 100
    eps_LB = -1
    eps_UB = 1
    counter = 0
    is_pos = True
    is_neg = True

    # perform binary search
    eps_gx_UB = 1000000.0
    eps_gx_LB = 0.0
    is_pos = True
    is_neg = True
    # eps = eps_gx_LB*2
    # eps = args.eps

    while eps_gx_UB - eps_gx_LB > 0.00001:
        ptb = PerturbationLpNorm(norm=2, eps=eps)
        image = BoundedTensor(input, ptb)
        pred = model(image)
        label = torch.argmax(pred, dim=1).cpu().numpy()
        # for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)']:
        lb, ub = model.compute_bounds(x=(image,), method='IBP+backward')
        gap_gx = torch.min(lb)
        lb = lb.detach().cpu().numpy()
        ub = ub.detach().cpu().numpy()
        print("Bounding method:", method)
        for i in range(N):
            print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i]))
            for j in range(n_classes):
                indicator = '(ground-truth)' if j == true_label[i] else ''
                print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format(
                    j=j, l=lb[i][j], u=ub[i][j], ind=indicator))
        print()
        if gap_gx > 0:
            if gap_gx < 0.01:
                eps_gx_LB = eps
                return eps
                break
            if is_pos:  # so far always > 0, haven't found eps_UB
                eps_gx_LB = eps
                eps *= 10
            else:
                eps_gx_LB = eps
                eps = (eps_gx_LB + eps_gx_UB) / 2
            is_neg = False
        else:
            if is_neg:  # so far always < 0, haven't found eps_LB
                eps_gx_UB = eps
                eps /= 10
            else:
                eps_gx_UB = eps
                eps = (eps_gx_LB + eps_gx_UB) / 2
            is_pos = False
        counter += 1
        if counter >= 500:
            return eps
            break
    print("[L2][binary search] step = {}, eps = {:.5f}, gap_gx = {:.2f}".format(counter, eps, gap_gx))
```

However, I am not getting the results I expected. Where did it go wrong?
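One point worth checking in the posted code: `gap_gx = torch.min(lb)` takes the smallest lower bound over *all* logits. Robustness verification instead checks the margin `f_y(x) - f_j(x) > 0` for the true class `y` and every other class `j`.

The sketch below is a NumPy illustration of two ways to get a margin bound. The helper names are hypothetical.

- `spec_matrix` builds a specification matrix whose rows are `e_y - e_j`. Assuming your auto_LiRPA version accepts a `C` argument in `compute_bounds` (with a leading batch dimension, i.e. `C[None]`), CROWN can then bound the margins directly.
- `margin_lower_bound_from_intervals` derives a looser margin bound from per-logit intervals.

```python
import numpy as np

def spec_matrix(true_label, n_classes):
    """Rows are e_y - e_j for every j != y, so (C @ f)[k] = f_y - f_j."""
    others = [j for j in range(n_classes) if j != true_label]
    C = np.zeros((n_classes - 1, n_classes))
    for row, j in enumerate(others):
        C[row, true_label] = 1.0
        C[row, j] = -1.0
    return C

def margin_lower_bound_from_intervals(lb, ub, true_label):
    """Loose but sound bound: f_y - f_j >= lb[y] - ub[j] for each j != y."""
    others = [j for j in range(len(lb)) if j != true_label]
    return np.array([lb[true_label] - ub[j] for j in others])

# Made-up logit bounds for one image whose true label is 0.
lb = np.array([2.0, -1.0, 0.5])
ub = np.array([3.0, 1.5, 1.8])
m = margin_lower_bound_from_intervals(lb, ub, true_label=0)
print("margin lower bounds:", m)
print("verified:", m.min() > 0)
```

Bounding `C @ f` directly is generally tighter than `lb[y] - ub[j]`, because the interval version ignores the correlation between `f_y` and `f_j`.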

huanzhang12 commented 3 years ago

See #5
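Two further things stand out in the posted loop.

- It starts at `eps = 0`, so `eps *= 10` never moves away from zero.
- It returns as soon as `0 < gap_gx < 0.01`, before the bracket has narrowed.

Below is a minimal sketch of the search written against a boolean `is_certified(eps)` callback. The function name, its parameters, and the callback are hypothetical. It assumes certification is monotone in `eps` (if it holds at some `eps`, it holds at every smaller one). Like the posted code, it grows `eps` by factors of 10 until a failure brackets the radius, then bisects.

```python
def certified_radius(is_certified, eps_init=0.01, eps_max=1e6, tol=1e-5, max_steps=500):
    """Largest eps (to within tol) for which is_certified(eps) is True."""
    lo, hi = 0.0, None  # lo: known certified; hi: known to fail
    eps = eps_init      # must be > 0, otherwise eps *= 10 stays at 0
    steps = 0
    # Phase 1: grow eps until certification fails, giving a bracket [lo, hi].
    while hi is None:
        if steps >= max_steps or eps > eps_max:
            return lo
        steps += 1
        if is_certified(eps):
            lo = eps
            eps *= 10
        else:
            hi = eps
    # Phase 2: bisect; lo always certifies, hi never does.
    while hi - lo > tol and steps < max_steps:
        steps += 1
        mid = (lo + hi) / 2
        if is_certified(mid):
            lo = mid
        else:
            hi = mid
    return lo

# Hypothetical glue (not run here). It assumes an auto_LiRPA model `lirpa_model`,
# an input `x`, and a spec matrix `C` of shape (1, n_classes - 1, n_classes):
#   def is_certified(eps):
#       image = BoundedTensor(x, PerturbationLpNorm(norm=2, eps=eps))
#       lb, _ = lirpa_model.compute_bounds(x=(image,), method='IBP+backward', C=C)
#       return bool((lb > 0).all())

print(certified_radius(lambda e: e <= 0.3))  # toy predicate with true radius 0.3
```

Bisecting on a yes/no answer avoids tuning a threshold like `gap_gx < 0.01`. The returned value is always an `eps` that was actually certified.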