Verified-Intelligence / alpha-beta-CROWN

alpha-beta-CROWN: An Efficient, Scalable and GPU Accelerated Neural Network Verifier (winner of VNN-COMP 2021, 2022, 2023, and 2024)

Issue with type mismatch in `get_sparse_C` and `backward_general` functions #80

Open 929937690 opened 4 days ago

I encountered a type mismatch when using alpha-beta-CROWN. Specifically, when `get_sparse_C` is called with `unstable_size > crown_batch_size`, the variable `newC` is set to the string `'Patches'` instead of an actual `Patches` object. The string fails the `isinstance(C, Patches)` check in `backward_general`, falls through to the `else` branch, and triggers an `AssertionError` there, because that branch expects a `torch.Tensor`, `eyeC`, or `OneHotC` object rather than a string.

The resulting error message is: `AssertionError: <class 'str'>`

I believe the same issue can arise when `newC` is set to the string `'eye'` (the `dim > crown_batch_size` case) rather than to an `eyeC` object, which would lead to the same assertion error.
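The failure mode can be reproduced without auto_LiRPA. The sketch below uses hypothetical stand-in classes (`Patches`, `eyeC`, `OneHotC`, and `Tensor` in place of `torch.Tensor`; none of them are the real library classes) and a hypothetical `check_C` helper that mirrors the type dispatch in `backward_general`. A string sentinel fails the `isinstance(C, Patches)` test, falls through to the `else` branch, and trips the assertion, whose message is simply `type(C)`:

```python
# Hypothetical stand-ins for the auto_LiRPA / torch classes; NOT the real implementations.
class Patches:
    pass


class eyeC:
    pass


class OneHotC:
    pass


class Tensor:  # stand-in for torch.Tensor
    pass


def check_C(C):
    """Mimics the type dispatch at the start of backward_general."""
    if isinstance(C, Patches):
        return "patches branch"
    # Same assertion as the library excerpt: the assertion message is type(C).
    assert isinstance(C, (Tensor, eyeC, OneHotC)), type(C)
    return "tensor/eyeC/OneHotC branch"


# Feed both string sentinels that get_sparse_C can produce.
errors = {}
for sentinel in ("Patches", "eye"):
    try:
        check_C(sentinel)
    except AssertionError as err:
        errors[sentinel] = str(err)

print(errors)  # {'Patches': "<class 'str'>", 'eye': "<class 'str'>"}
```

Both sentinels produce exactly the message reported above, which supports the suspicion that `'eye'` would fail the same way if it ever reached this check.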

Is this behavior intentional, or is there a recommended workaround to avoid these type mismatches? I'd appreciate any insight into handling these cases. So far, only the `'Patches'` string assignment has triggered the error for me.

```python
# bound_general.py
def compute_intermediate_bounds(self, ...):
    ...
    sparse_C = self.get_sparse_C(node, ref_intermediate)
    ...
    ... = self.backward_general(...)
    ...

# backward_bound.py
def get_sparse_C(self, node, ref_intermediate):
    ...
    if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int(
            os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0:
        ...
            if not reduced_dim:
                if dim > crown_batch_size:
                    newC = 'eye'
                else:
                    newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device)
    elif node.patches_start and node.mode == "patches":
        if sparse_intermediate_bounds:
            ...
            elif unstable_size > crown_batch_size:
                newC = 'Patches'
                reduced_dim = True
    ...
    else:
        ...

        if not reduced_dim:
            ...
            if dim > crown_batch_size:
                newC = 'eye'
            else:
                newC = torch.eye(dim, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1
                ).view(batch_size, dim, *node.output_shape[1:])
    ...

def backward_general(self, ...):
    ...
    if self.infeasible_bounds is None:
        if isinstance(C, Patches):
            self.infeasible_bounds = torch.full((C.shape[1],), False, device=device)
        else:
            # If C is a Tensor/eyeC/OneHotC object, we go into the second branch
            assert isinstance(C, (torch.Tensor, eyeC, OneHotC)), type(C)
            self.infeasible_bounds = torch.full((C.shape[0],), False, device=device)
    ...
```

The error is shown in the following screenshot:

[Screenshot from 2024-11-09 21-59-05]
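As a possible workaround (not the library's own mechanism), one could resolve sentinel values before they reach `backward_general`. The sketch below assumes that `'eye'` means "C is a `dim` x `dim` identity that is too large to build at once" and expands it into row slices of at most `crown_batch_size`, represented as plain nested lists for illustration. `chunk_ranges`, `identity_slice`, and `resolve_C` are hypothetical helpers. Building a real `Patches` object requires library internals, so the sketch only rejects the `'Patches'` sentinel with a descriptive `TypeError` instead of the opaque `AssertionError`:

```python
from typing import Iterator, List, Tuple


def chunk_ranges(dim: int, crown_batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield half-open [start, end) ranges covering 0..dim, each at most crown_batch_size long."""
    if crown_batch_size <= 0:
        raise ValueError("crown_batch_size must be positive")
    for start in range(0, dim, crown_batch_size):
        yield start, min(start + crown_batch_size, dim)


def identity_slice(start: int, end: int, dim: int) -> List[List[float]]:
    """Rows start..end-1 of a dim x dim identity matrix, i.e. a dense C for one chunk."""
    return [[1.0 if col == row else 0.0 for col in range(dim)] for row in range(start, end)]


def resolve_C(C, dim: int, crown_batch_size: int) -> list:
    """Hypothetical guard: turn a string sentinel into concrete chunks, or fail loudly."""
    if isinstance(C, str):
        if C == "eye":
            # Build the identity as slices of at most crown_batch_size rows each.
            return [identity_slice(s, e, dim) for s, e in chunk_ranges(dim, crown_batch_size)]
        # Building a real Patches object needs library internals; reject with a clear message.
        raise TypeError(f"unresolved C sentinel {C!r}; materialize it before backward_general")
    # Already a concrete object: treat it as a single chunk.
    return [C]


chunks = resolve_C("eye", dim=5, crown_batch_size=2)
print([len(chunk) for chunk in chunks])  # [2, 2, 1]
```

For example, `resolve_C("eye", dim=5, crown_batch_size=2)` returns three chunks with 2, 2, and 1 rows whose concatenation is the 5x5 identity. Processing chunk by chunk keeps peak memory proportional to `crown_batch_size` rather than `dim`, which is presumably why the sentinel exists in the first place.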