mhmd97z opened this issue 11 months ago
Does `self.eye(after_conv5)[0]` alter the batch dimension (as mentioned in issue #41)? Can you try returning `self.eye(after_conv5)` instead?
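To illustrate the batch-dimension concern, here is a minimal sketch. It assumes the end of the model is a ReLU followed by a square `nn.Linear` named `eye`. The class name `Head`, the width of 4, and the batch size of 8 are hypothetical. Indexing the output with `[0]` selects the first sample and drops the batch dimension. Returning the Linear output directly keeps that dimension:

```python
import torch
import torch.nn as nn


class Head(nn.Module):
    """Hypothetical stand-in for the end of the model: ReLU followed by a Linear named `eye`."""

    def __init__(self, width: int = 4):
        super().__init__()
        self.eye = nn.Linear(width, width)

    def forward_indexed(self, after_conv5: torch.Tensor) -> torch.Tensor:
        # [0] picks the first sample, so the output loses its batch dimension.
        return self.eye(torch.relu(after_conv5))[0]

    def forward(self, after_conv5: torch.Tensor) -> torch.Tensor:
        # Returning the Linear output directly keeps shape (batch, width).
        return self.eye(torch.relu(after_conv5))


head = Head(width=4)
x = torch.randn(8, 4)  # batch of 8 inputs
print(head.forward_indexed(x).shape)  # torch.Size([4])
print(head(x).shape)                  # torch.Size([8, 4])
```

A verifier that propagates bounds for a whole batch of inputs or subdomains at once needs every layer to keep that leading dimension. The indexed variant silently breaks this requirement.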
Thanks for the prompt response. I just tried that and got the following error:
```
Traceback (most recent call last):
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 612, in <module>
    abcrown.main()
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 591, in main
    verified_status = self.complete_verifier(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 416, in complete_verifier
    l, nodes, ret = self.bab(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/abcrown.py", line 241, in bab
    result = general_bab(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 340, in general_bab
    global_lb = act_split_round(
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 165, in act_split_round
    split_domain(net, domains, d, batch, impl_params=impl_params,
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/bab.py", line 70, in split_domain
    branching_heuristic.get_branching_decisions(
  File "/home/mzi/anaconda3/envs/crown/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/mzi/sys-rl-verif/alpha-beta-CROWN/complete_verifier/heuristics/kfsb.py", line 162, in get_branching_decisions
    (k_ret_lbs.view(-1) - torch.cat([mask_score, mask_itb]).repeat(2) * 999999).reshape(2, -1),
RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0
```
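The final line of the traceback is an elementwise subtraction between two 1-D tensors of different lengths. PyTorch broadcasting rejects this unless one of the sizes is 1. The following is a minimal sketch using dummy tensors of the reported sizes. Only the shapes come from the traceback; the values and the helper `try_subtract` are hypothetical:

```python
import torch

# Placeholder tensors with the sizes reported in the traceback; their contents are dummies.
lbs_flat = torch.zeros(16)  # stands in for k_ret_lbs.view(-1)
masks = torch.zeros(4)      # stands in for torch.cat([mask_score, mask_itb]).repeat(2)


def try_subtract(a: torch.Tensor, b: torch.Tensor):
    """Return (difference, None), or (None, error message) if broadcasting fails."""
    try:
        return a - b * 999999, None
    except RuntimeError as err:
        return None, str(err)


result, message = try_subtract(lbs_flat, masks)
print(message)  # the same shape-mismatch error as in the traceback

# With matching lengths, the same expression broadcasts without error.
ok, _ = try_subtract(lbs_flat, torch.zeros(16))
print(ok.shape)  # torch.Size([16])
```

In the actual run, the flattened lower bounds therefore hold four times as many entries as the repeated masks. The shapes produced by the model's last layers are a reasonable first place to look.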
Hi, I found that CROWN does not support having the ReLU function as the last layer (it throws an error), so I thought I could add an identity Linear layer after the ReLU as follows:
However, this gives me the following error:
P.S.: The error goes away when I remove this parameter-setting line:
```python
self.eye.weight = torch.nn.Parameter(torch.ones_like(self.eye.weight))
```
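Note that `torch.ones_like(self.eye.weight)` creates an all-ones matrix, not an identity matrix. A layer configured this way sums its inputs rather than passing them through unchanged. The sketch below builds an actual identity `nn.Linear` by copying `torch.eye` into the existing weight in place under `torch.no_grad()`, and compares it with the all-ones setup. The feature width of 4 and the variable names are hypothetical. The sketch only illustrates PyTorch semantics and is not a confirmed fix for the CROWN error:

```python
import torch
import torch.nn as nn

width = 4  # hypothetical feature size of the layer after the final ReLU

# Identity layer: weight = I, bias = 0, written in place so the Parameter objects are kept.
identity = nn.Linear(width, width)
with torch.no_grad():
    identity.weight.copy_(torch.eye(width))
    identity.bias.zero_()

# Layer set up as in the issue: every weight entry is 1 (bias left at its random init).
all_ones = nn.Linear(width, width)
all_ones.weight = nn.Parameter(torch.ones_like(all_ones.weight))

x = torch.relu(torch.randn(3, width))  # batch of 3 post-ReLU activations
identity_out = identity(x)
ones_out = all_ones(x)

print(torch.allclose(identity_out, x))  # True: inputs pass through unchanged
# Each output of the all-ones layer is the sum of all inputs plus that unit's bias.
print(torch.allclose(ones_out, x.sum(dim=1, keepdim=True) + all_ones.bias))  # True
```

Copying into the existing weight with `copy_` keeps the same `Parameter` object registered on the module. Assigning a new `nn.Parameter` replaces it.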