Verified-Intelligence / alpha-beta-CROWN

alpha-beta-CROWN: An Efficient, Scalable and GPU Accelerated Neural Network Verifier (winner of VNN-COMP 2021, 2022, 2023, and 2024)

Error when changing last layer's weight #42

Open mhmd97z opened 11 months ago

mhmd97z commented 11 months ago

Hi, I found that CROWN does not support having a ReLU as the last layer (it throws an error), so I tried adding an identity Linear layer after the final ReLU, as follows:

import torch
import torch.nn as nn


class MyModel(nn.ModuleList):
    def __init__(self, device=torch.device("cpu")):
        super(MyModel, self).__init__()
        self.to(device)

        # Shared activation and fully connected layers
        self.af = nn.ReLU()
        self.lin1 = nn.Linear(19, 32)
        self.lin2 = nn.Linear(32, 32)
        self.out = nn.Linear(32, 15)

        # 1-D convolution stack applied to the logits
        self.c2d_1 = torch.nn.Conv1d(1, 24, 15)
        self.c2d_2 = torch.nn.Conv1d(24, 12, 1)
        self.c2d_3 = torch.nn.Conv1d(12, 6, 1)
        self.c2d_4 = torch.nn.Conv1d(6, 3, 1)
        self.c2d_5 = torch.nn.Conv1d(3, 1, 1)

        # Identity Linear layer (weight 1, bias 0) so the network does not end in a ReLU
        self.eye = torch.nn.Linear(1, 1)
        self.eye.weight = torch.nn.Parameter(torch.ones_like(self.eye.weight))
        self.eye.bias = torch.nn.Parameter(torch.tensor([0.], requires_grad=True))

    def forward(self, obs):
        obs = self.af(self.lin1(obs))
        obs = self.af(self.lin2(obs))
        logits = self.out(obs)
        after_conv1 = self.af(self.c2d_1(logits))
        after_conv2 = self.af(self.c2d_2(after_conv1))
        after_conv3 = self.af(self.c2d_3(after_conv2))
        after_conv4 = self.af(self.c2d_4(after_conv3))
        after_conv5 = self.af(self.c2d_5(after_conv4))
        return self.eye(after_conv5)[0]

Nonetheless, it gives me the following error:

Traceback (most recent call last):
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 612, in <module>
    abcrown.main()
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 591, in main
    verified_status = self.complete_verifier(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 416, in complete_verifier
    l, nodes, ret = self.bab(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 241, in bab
    result = general_bab(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 340, in general_bab
    global_lb = act_split_round(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 165, in act_split_round
    split_domain(net, domains, d, batch, impl_params=impl_params,
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 70, in split_domain
    branching_heuristic.get_branching_decisions(
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/heuristics/kfsb.py", line 143, in get_branching_decisions
    k_ret_lbs = self.net.update_bounds(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/beta_CROWN_solver.py", line 303, in update_bounds
    lb, _, = self.net.compute_bounds(
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 1206, in compute_bounds
    return self._compute_bounds_main(C=C,
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 1311, in _compute_bounds_main
    ret = self.backward_general(
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/backward_bound.py", line 324, in backward_general
    lb, ub = concretize(self, batch_size, output_dim, lb, ub,
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/backward_bound.py", line 671, in concretize
    lA = roots[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None
RuntimeError: shape '[1, 4, -1]' is invalid for input of size 1

P.S.: the error goes away when I remove the line that sets the weight:

self.eye.weight = torch.nn.Parameter(torch.ones_like(self.eye.weight))
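
A note on reading the RuntimeError in the traceback above: a call like reshape(output_dim, batch_size, -1) can only succeed when the number of elements is divisible by output_dim * batch_size. The pure-Python sketch below mirrors that size check. It only illustrates the reshape arithmetic and is not auto_LiRPA code; infer_reshape is a hypothetical helper introduced here.

```python
from math import prod


def infer_reshape(numel, shape):
    """Resolve a single -1 in `shape` for a tensor with `numel` elements.

    Returns the concrete shape as a tuple, or None if the reshape is
    impossible. Hypothetical helper that mimics the reshape size check.
    """
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    # Product of the explicitly given dimensions (1 if there are none).
    known = prod(d for d in shape if d != -1)
    if -1 not in shape:
        return tuple(shape) if known == numel else None
    if known == 0 or numel % known != 0:
        return None
    return tuple(numel // known if d == -1 else d for d in shape)


print(infer_reshape(1, [1, 4, -1]))  # None: 1 element cannot fill 1 x 4 x ?
print(infer_reshape(4, [1, 4, -1]))  # (1, 4, 1)
```

In the traceback, the tensor being reshaped holds a single element while the requested shape assumes a batch of 4, so no value of the last dimension can satisfy the reshape.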

shizhouxing commented 11 months ago

Does self.eye(after_conv5)[0] alter the batch dimension (as mentioned in issue #41)? Can you try returning self.eye(after_conv5) instead?
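
To make the question concrete, the sketch below tracks only tensor shapes as plain tuples, so it runs without PyTorch. The (batch, channels, length) shape fed to the final layer and the batch size of 4 are assumptions chosen for illustration; index_first and linear_out are hypothetical helpers, not library functions.

```python
def index_first(shape):
    """Shape that results from `tensor[0]`: the leading dimension is removed."""
    if not shape:
        raise IndexError("cannot index a 0-dimensional tensor")
    return tuple(shape[1:])


def linear_out(shape, out_features):
    """nn.Linear acts on the last dimension only and keeps all leading ones."""
    return tuple(shape[:-1]) + (out_features,)


batch = 4                          # illustrative batch size (assumption)
conv_out = (batch, 1, 1)           # assumed (batch, channels, length)
eye_out = linear_out(conv_out, 1)  # shape after the identity Linear layer

print(eye_out)               # (4, 1, 1): batch dimension preserved
print(index_first(eye_out))  # (1, 1): batch dimension gone
```

Under these assumptions, returning the layer output without [0] keeps the leading batch dimension intact, which is what the suggestion above is testing.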

mhmd97z commented 11 months ago

Thanks for the prompt response. I just tried that and got the following error:

Traceback (most recent call last):
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 612, in <module>
    abcrown.main()
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 591, in main
    verified_status = self.complete_verifier(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 416, in complete_verifier
    l, nodes, ret = self.bab(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 241, in bab
    result = general_bab(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 340, in general_bab
    global_lb = act_split_round(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 165, in act_split_round
    split_domain(net, domains, d, batch, impl_params=impl_params,
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 70, in split_domain
    branching_heuristic.get_branching_decisions(
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/heuristics/kfsb.py", line 162, in get_branching_decisions
    (k_ret_lbs.view(-1) - torch.cat([mask_score, mask_itb]).repeat(2) * 999999).reshape(2, -1),
RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0