Verified-Intelligence / alpha-beta-CROWN

alpha-beta-CROWN: An Efficient, Scalable and GPU Accelerated Neural Network Verifier (winner of VNN-COMP 2021, 2022, 2023, and 2024)

Out-Of-Memory Error / "Killed" #58

Closed: jannickstrobel closed this issue 3 months ago

jannickstrobel commented 6 months ago

Describe the bug

I am trying to verify a metric-based specification by joining the original network and the metric into a single new network.

After a while, branch and bound (BaB) terminates with "Killed". According to the system logs, the process is killed by the Linux OOM killer.

Is there a configuration option that enables verification of larger models? I am aware that my network is hard to verify because of the large input space, but is there perhaps an option to use external memory?

To Reproduce

The specification constrains every input variable to the range [-1, 1] and asserts that the output is <= 0:

(assert (<= X_0 1))
(assert (>= X_0 -1))
(assert (<= X_1 1))
(assert (>= X_1 -1))
...
(assert (<= Y_0 0))
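Writing one pair of bounds per input by hand is tedious for a large input_size. The sketch below generates such a file under stated assumptions: the helper name make_vnnlib is hypothetical, and the declare-const lines are added because VNN-LIB files declare each variable before asserting on it (the excerpt above does not show them).

```python
# Hypothetical helper: builds a VNN-LIB spec that bounds every input X_i to
# [lo, hi] and adds the single output assertion used in this issue.
def make_vnnlib(num_inputs: int, lo: float = -1.0, hi: float = 1.0) -> str:
    lines = []
    # Declare every input variable and the single output variable first.
    for i in range(num_inputs):
        lines.append(f"(declare-const X_{i} Real)")
    lines.append("(declare-const Y_0 Real)")
    # Box constraints: one upper and one lower bound per input dimension.
    for i in range(num_inputs):
        lines.append(f"(assert (<= X_{i} {hi}))")
        lines.append(f"(assert (>= X_{i} {lo}))")
    # Output assertion, as in the excerpt above.
    lines.append("(assert (<= Y_0 0))")
    return "\n".join(lines) + "\n"
```

For example, make_vnnlib(input_size) returns the full text, which can then be written to a .vnnlib file.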

Network:

...
        self.sequential = nn.Sequential(
            nn.Linear(input_size, relu_layer_size),
            nn.ReLU(),
            nn.Linear(relu_layer_size, relu_layer_size),
            nn.ReLU(),
            nn.Linear(relu_layer_size, 43)
        )

...

    def forward(self, input):
        def metric_mse(input1, input2):
            input1 = input1.to(device='cuda')
            input2 = input2.to(device='cuda')
            temp = (input1 - input2) ** 2
            # Per-sample mean squared error over all non-batch dimensions
            mse = torch.mean(temp, dim=tuple(range(1, input1.dim())), keepdim=True)
            similarity = 1 - mse
            return similarity

        def join(input):
            metric_val  = metric_mse(input, self.reference)
            pred        = self.sequential(input).to(device='cuda')
            # Score margin of class self.cl over every class, computed per sample
            # (slicing keeps the batch dimension instead of using row 0 for all rows)
            diff        = pred[:, self.cl:self.cl + 1] - pred
            min_val     = torch.min(diff, dim=1, keepdim=True)[0]
            met_val     = self.threshold - metric_val
            concat      = torch.cat((min_val, met_val), dim=1)
            output      = torch.max(concat, dim=1, keepdim=True)[0]
            return output

        # forward has to return the joined output, shape (batch, 1)
        return join(input)
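The combination at the end of join relies on a property of max: max(min_val, met_val) <= 0 holds exactly when both min_val <= 0 and met_val <= 0 (the similarity is at least the threshold). A minimal pure-Python sketch of that scalar logic is shown here, with made-up numbers and the hypothetical name join_scalar. Note that diff also contains the entry for class self.cl itself, which for a single input is pred[0, self.cl] - pred[0, self.cl] = 0, so min_val is never positive.

```python
# Pure-Python sketch of the scalar logic in join() for one input (no torch).
# All names and numbers here are hypothetical.
def join_scalar(pred: list[float], cl: int, similarity: float, threshold: float) -> float:
    # Margin of class cl over every class, including cl itself (that entry is 0).
    diff = [pred[cl] - p for p in pred]
    min_val = min(diff)                # never positive, because diff[cl] == 0
    met_val = threshold - similarity   # <= 0 when similarity >= threshold
    # max(a, b) <= 0 exactly when a <= 0 and b <= 0.
    return max(min_val, met_val)


# Similar input, class cl still on top: output is 0.0 (min_val is 0, met_val < 0).
print(join_scalar([2.0, 1.0, 0.5], 0, 0.99, 0.95))
# Dissimilar input: met_val > 0 dominates, so output is positive.
print(join_scalar([2.0, 1.0, 0.5], 0, 0.50, 0.95))
```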

System configuration:

Error trace

BaB round 299
batch: 2048
Start filtering...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 37.06it/s]
kfsb scores (first 10): tensor([[-1.06886995, -0.81120229, -1.83464122, -0.79619068, -1.05805361,
         -0.76246947, -0.60957587, -0.60908198, -0.45265666, -1.59237599],
        [-1.06886995, -0.81120229, -1.83464122, -0.79619068, -1.05805361,
         -0.76246947, -0.60957587, -0.60908198, -0.45265666, -1.59237599],
        [-1.06886995, -0.81120229, -1.83464122, -0.79619068, -1.05805361,
         -0.76246947, -0.60957587, -0.60908198, -0.45265666, -1.59237599]],
       device='cuda:0')
kfsb choice (first 10): tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
Filtering time: 0.5150580406188965
splitting decisions:
split level 0: [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2]
Time: prepare 0.8080    bound 0.3880    transfer 0.0025    finalize 0.3109    func 1.5112
Accumulated time: func 596.2728    prepare 279.8262    bound 152.9367    transfer 1.1092    finalize 154.5267
Current worst splitting domains lb-rhs (depth):
-0.04100 (64), -0.04100 (15), -0.04100 (71), -0.04100 (38), -0.04100 (34), -0.04100 (46), -0.04100 (39), -0.04100 (78), -0.04100 (73), -0.04100 (19), -0.04100 (43), -0.04100 (51), -0.04100 (53), -0.04100 (63), -0.04100 (67), -0.04100 (57), -0.04100 (32), -0.04100 (39), -0.04100 (32), -0.04100 (41),
Length of domains: 589674
Time: pickout 0.1445    decision 0.5791    set_bounds 0.4862    solve 1.5135    add 0.0661
Accumulated time: pickout 35.6028    decision 179.8644    set_bounds 220.6298    solve 597.1098    add 182.1249
Sorting batched domains takes 6.2405476570129395 seconds.
Current (lb-rhs): -0.04099998623132706
1179498 domains visited
Cumulative time: 2191.093161344528

BaB round 300
batch: 2048
Start filtering...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15.62it/s]
kfsb scores (first 10): tensor([[-1.03997588, -0.33855882, -0.52589869, -1.07941067, -0.91442627,
         -0.56939507, -0.48885238, -0.71086031, -0.99345517, -0.88745856],
        [-1.03997588, -0.33855882, -0.52589869, -1.07941067, -0.91442627,
         -0.56939507, -0.48885238, -0.71086031, -0.99345517, -0.88745856],
        [-1.03997588, -0.33855885, -0.52589869, -1.07941067, -0.91442627,
         -0.56939507, -0.48885238, -0.71086031, -0.99345517, -0.88745850]],
       device='cuda:0')
kfsb choice (first 10): tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 2], device='cuda:0')
Filtering time: 1.3492143154144287
splitting decisions:
split level 0: [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 2] [/input, 0]
Time: prepare 1.8817    bound 1.0160    transfer 0.0195    finalize 1.2457    func 4.1651
Accumulated time: func 600.4379    prepare 281.7092    bound 153.9527    transfer 1.1287    finalize 155.7724
Current worst splitting domains lb-rhs (depth):
-0.04100 (29), -0.04100 (33), -0.04100 (41), -0.04100 (25), -0.04100 (89), -0.04100 (31), -0.04100 (36), -0.04100 (30), -0.04100 (57), -0.04100 (56), -0.04100 (41), -0.04100 (38), -0.04100 (66), -0.04100 (50), -0.04100 (84), -0.04100 (68), -0.04100 (26), -0.04100 (34), -0.04100 (77), -0.04100 (16),
Length of domains: 591722
Time: pickout 0.3452    decision 1.4593    set_bounds 1.2116    solve 4.1870    add 15.8896
Accumulated time: pickout 35.9480    decision 181.3237    set_bounds 221.8413    solve 601.2968    add 198.0145
Killed
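The two rounds in the trace show the number of open domains growing from 589674 to 591722, i.e. by exactly 2048, the batch size. Assuming (the log does not confirm this) that each of the 2048 selected domains is split into two children and that verified children are discarded, the net growth per round is the batch size minus the number of verified children, so round 300 verified none. A small hypothetical helper to extract that trend from a saved log:

```python
import re


def domain_growth(log_text: str) -> list[int]:
    """Change in 'Length of domains' between consecutive BaB rounds in a log."""
    counts = [int(n) for n in re.findall(r"Length of domains: (\d+)", log_text)]
    # Pair each count with the next one and take the difference.
    return [later - earlier for earlier, later in zip(counts, counts[1:])]


log = "Length of domains: 589674\n...\nLength of domains: 591722\n"
print(domain_growth(log))  # [2048]
```

A growth that stays equal to the batch size round after round means memory use keeps rising linearly until the process is killed.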

Additional context

An error message indicating that the program was terminated by the OOM killer would be helpful.
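Until the verifier prints such a message itself, a wrapper can report when the process dies from SIGKILL, which is the signal the Linux OOM killer sends. The sketch uses Python's subprocess module, where on POSIX a negative returncode -N means the child was terminated by signal N; the function name run_and_report is hypothetical. SIGKILL can also come from other sources, so the kernel log remains the authoritative check.

```python
import signal
import subprocess
import sys


def run_and_report(cmd: list[str]) -> int:
    """Run cmd and print a hint if it was terminated by SIGKILL (POSIX only)."""
    result = subprocess.run(cmd)
    # A negative returncode -N means the child was killed by signal N.
    if result.returncode == -signal.SIGKILL:
        print(
            "Process was killed by SIGKILL; on Linux this is often the OOM killer "
            "(check `dmesg` or the kernel log).",
            file=sys.stderr,
        )
    return result.returncode
```

For example, run_and_report(["python", "abcrown.py", "--config", "my_config.yaml"]), where my_config.yaml stands in for your own configuration file.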

shizhouxing commented 3 months ago

Hi @jannickstrobel, this indicates that there are too many unverified domains in branch-and-bound, so verification is unlikely to succeed for this model. The input range may be too large ([-1, 1] for each dimension).
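To see why a large input range is costly for input splitting, consider a worst-case picture in which the verifier must tile the whole box [lo, hi]^n with small boxes of a fixed side width. The count grows exponentially with the input dimension n, and halving the range in every dimension divides it by 2^n. This is only an illustration under that uniform-tiling assumption; branch-and-bound splits adaptively and may verify large regions without refining them. The function name and numbers are hypothetical.

```python
import math


def boxes_to_cover(num_dims: int, lo: float, hi: float, width: float) -> int:
    """Boxes of side `width` needed to tile [lo, hi]^num_dims uniformly."""
    per_dim = math.ceil((hi - lo) / width)  # boxes along a single axis
    return per_dim ** num_dims              # one box per combination of axis cells


print(boxes_to_cover(10, -1.0, 1.0, 0.5))  # 1048576 (= 4**10)
print(boxes_to_cover(10, -0.5, 0.5, 0.5))  # 1024 (= 2**10)
```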