Verified-Intelligence / auto_LiRPA

auto_LiRPA: An Automatic Linear Relaxation based Perturbation Analysis Library for Neural Networks and General Computational Graphs
https://arxiv.org/pdf/2002.12920

verification of ensemble model - how to max multiple predictions? #72

Open luigiberducci opened 4 weeks ago


Hello auto_LiRPA team, thanks for the great work you are doing! :)

I recently started playing with your tool, but I am struggling to adapt the MIP verification code to my use case. In particular, I am trying to verify an ensemble model that takes the maximum over the predictions of several individual MLP models. For example, given a set of MLP models M and an input domain D, I would like to check that the maximum score is > 0 everywhere on D, i.e., even at the worst-case input:

$$x^* = \arg\min_{x \in D} \max_{m \in M} m(x), \qquad \max_{m \in M} m(x^*) > 0$$

To deal with the inner max, I wanted to create an ensemble model that, in the forward pass, computes the individual scores and returns their maximum, $\mathrm{ensemble}(x) = \max_{m \in M} m(x)$, so that the property becomes:

$$x^* = \arg\min_{x \in D} \mathrm{ensemble}(x), \qquad \mathrm{ensemble}(x^*) > 0$$
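The property can be bracketed cheaply before running a full verifier. The sketch below is a hypothetical helper, not part of auto_LiRPA: it uses naive interval bound propagation (IBP), a looser technique than the alpha-CROWN bounds auto_LiRPA computes, to get a sound lower bound on the minimum of ensemble(x) over the box D, and random sampling to get an upper bound. Because max is monotone in each argument, interval bounds for the ensemble are simply the elementwise max of the per-model bounds. The names `make_mlp`, `interval_bounds`, `ensemble_bounds` and `sampled_min` are illustrative assumptions.

```python
import torch
import torch.nn as nn


def make_mlp() -> nn.Sequential:
    # Same architecture as net1/net2 in this issue: 10 -> 20 -> 1 with ReLU.
    return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))


def interval_bounds(net: nn.Sequential, lo: torch.Tensor, hi: torch.Tensor):
    """Naive interval bound propagation through Linear/ReLU layers."""
    for layer in net:
        if isinstance(layer, nn.Linear):
            center, radius = (hi + lo) / 2, (hi - lo) / 2
            new_center = layer(center)                     # W c + b
            new_radius = radius @ layer.weight.abs().t()   # |W| r
            lo, hi = new_center - new_radius, new_center + new_radius
        elif isinstance(layer, nn.ReLU):
            lo, hi = lo.clamp(min=0), hi.clamp(min=0)      # ReLU is monotone
        else:
            raise NotImplementedError(type(layer).__name__)
    return lo, hi


def ensemble_bounds(models, lo, hi):
    # max is monotone in every argument: bounds are the max of the per-model bounds.
    los, his = zip(*(interval_bounds(m, lo, hi) for m in models))
    lower = torch.cat(los, dim=-1).max(dim=-1).values
    upper = torch.cat(his, dim=-1).max(dim=-1).values
    return lower, upper


def sampled_min(models, lo, hi, n=10_000):
    # Random points in the box give an upper bound on min_x ensemble(x) (falsification only).
    x = lo + (hi - lo) * torch.rand(n, lo.shape[-1])
    with torch.no_grad():
        ys = torch.cat([m(x) for m in models], dim=-1).max(dim=-1).values
    return ys.min()


torch.manual_seed(0)
models = [make_mlp(), make_mlp()]
lo, hi = torch.full((1, 10), -5.0), torch.full((1, 10), 5.0)
with torch.no_grad():
    lb, ub = ensemble_bounds(models, lo, hi)
print("IBP lower bound:", lb.item())
print("sampled minimum:", sampled_min(models, lo, hi).item())
```

If the IBP lower bound is already > 0 the property holds; if the sampled minimum is <= 0 a counterexample has been found; otherwise a tighter method (such as alpha-CROWN or MIP) is needed to decide.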

However, I ran into some issues, and I would appreciate feedback on whether I am tackling the problem the right way. Below is a minimal example of what I am trying to do.

Consider two networks, net1 and net2. As a first attempt, I created a MaxNet model that uses torch.max in the forward pass:

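```python
import torch
import torch.nn as nn


class MaxNet(nn.Module):
    def __init__(self, net1: nn.Module, net2: nn.Module):
        super().__init__()
        self.net1 = net1
        self.net2 = net2

    def forward(self, x):
        # Stack the per-model scores along the last dimension: shape (batch, 2).
        ys = [self.net1(x), self.net2(x)]
        ys = torch.cat(ys, dim=-1)
        # Max over the model dimension (torch.max with dim returns values, indices).
        max_ys, _ = torch.max(ys, dim=-1)
        return max_ys
```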

In a plain forward pass, it works as expected:

```python
net1 = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1),
)
net2 = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1),
)
maxnet = MaxNet(net1, net2)

x = torch.randn(2, 10)
maxnet(x)
```

```
tensor([[0.3049], [0.2367]], grad_fn=<MaxBackward0>)
```
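A small detail on shapes: the (2, 1) layout of the printed tensor corresponds to `keepdim=True`, because `torch.max(ys, dim=-1)` by default drops the reduced dimension and returns a (values, indices) pair. The standalone check below only illustrates that PyTorch behavior on a made-up `ys`; it does not touch auto_LiRPA.

```python
import torch

torch.manual_seed(0)
ys = torch.randn(2, 2)  # (batch, n_models): one column per sub-network score

values, indices = torch.max(ys, dim=-1)         # reduced dimension is dropped
kept, _ = torch.max(ys, dim=-1, keepdim=True)   # reduced dimension is kept

print(values.shape)  # torch.Size([2])
print(kept.shape)    # torch.Size([2, 1])
```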

However, when I try to verify that its output is > 0, computing the bounds fails with the following error:

```python
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

lirpa_model = BoundedModule(model=maxnet, global_input=torch.empty(size=(1, 10)))

# Input domain: the box [-5, 5]^10, centered at the origin.
xi_lower = torch.tensor([[-5.0] * 10])
xi_upper = torch.tensor([[5.0] * 10])
init_domain = PerturbationLpNorm(x_L=xi_lower, x_U=xi_upper)
bounded_domain = BoundedTensor(torch.tensor([[0.0] * 10]), init_domain)

# Call alpha-CROWN first, which gives all intermediate layer bounds.
lb, ub = lirpa_model.compute_bounds(x=(bounded_domain,), method='alpha-CROWN')
```

```
NotImplementedError: `bound_backward` for BoundReduceMax with perturbed maximumindexes is not implemented.
```
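One possible workaround, offered as an untested assumption rather than a recommendation from the auto_LiRPA maintainers, is to avoid the max-reduction node altogether and build the ensemble max from the identity max(a, b) = b + ReLU(a - b), applied pairwise. The resulting graph only contains Linear, ReLU, Sub and Add nodes, which are the kind of operators auto_LiRPA bounds for ordinary MLPs; whether the resulting bounds are tight enough for this property is not checked here. `ReluMaxNet` and `make_mlp` are hypothetical names.

```python
import torch
import torch.nn as nn


class ReluMaxNet(nn.Module):
    """Ensemble max built from ReLU instead of torch.max (hypothetical workaround)."""

    def __init__(self, nets):
        super().__init__()
        self.nets = nn.ModuleList(nets)

    def forward(self, x):
        ys = [net(x) for net in self.nets]  # each of shape (batch, 1)
        out = ys[0]
        for y in ys[1:]:
            # max(out, y) = y + relu(out - y): no ReduceMax node in the graph.
            out = y + torch.relu(out - y)
        return out


def make_mlp() -> nn.Sequential:
    return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))


torch.manual_seed(0)
nets = [make_mlp() for _ in range(3)]
model = ReluMaxNet(nets)
x = torch.randn(4, 10)
with torch.no_grad():
    reference = torch.cat([n(x) for n in nets], dim=-1).max(dim=-1, keepdim=True).values
    print(torch.allclose(model(x), reference, atol=1e-6))
```

The pairwise loop generalizes to any number of models; a balanced tree of pairwise maxes would give a shallower graph, but the sequential version is simpler to read.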

As a second attempt, I tried using a 1-D max-pooling layer:

```python
class MaxNet(nn.Module):
    def __init__(self, net1: nn.Module, net2: nn.Module):
        super().__init__()
        self.net1 = net1
        self.net2 = net2
        self.max_pool = nn.MaxPool1d(kernel_size=2)  # kernel size = number of models

    def forward(self, x):
        ys = [self.net1(x), self.net2(x)]
        ys = torch.cat(ys, dim=-1)
        max_ys = self.max_pool(ys)
        return max_ys
```
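In a plain forward pass this version computes the same quantity as the torch.max one. Recent PyTorch versions accept a 2-D input to `nn.MaxPool1d` and treat it as unbatched (C, L), so each row of `ys` is pooled over its two model scores. The check below uses a made-up `ys` and only exercises PyTorch, not auto_LiRPA.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ys = torch.randn(3, 2)               # (batch, n_models), as produced by torch.cat
pool = nn.MaxPool1d(kernel_size=2)   # kernel size = number of models

# 2-D input is read as (C, L): each row is a "channel" of length 2,
# so the result has shape (batch, 1) and equals a row-wise max.
pooled = pool(ys)
print(pooled.shape)  # torch.Size([3, 1])
```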

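This time, constructing the BoundedModule fails with the error below. My understanding is that only 2-D pooling is supported: the assertion in pooling.py indexes attr['pads'][2], which exists for the four-entry padding of a 2-D pooling node but not for the two-entry padding of a 1-D one.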

```
Traceback (most recent call last):
  File "/home/luigi/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/241.14494.241/plugins/python/helpers/pydev/pydevd.py", line 1535, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/luigi/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/241.14494.241/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/luigi/Development/auto_LiRPA/examples/control/minimal_example.py", line 39, in <module>
    lirpa_model = BoundedModule(model=maxnet, global_input=torch.empty(size=(1, 10)))
  File "/home/luigi/Development/auto_LiRPA/auto_LiRPA/bound_general.py", line 130, in __init__
    self._convert(model, global_input)
  File "/home/luigi/Development/auto_LiRPA/auto_LiRPA/bound_general.py", line 848, in _convert
    nodesOP, nodesIn, nodesOut, template = self._convert_nodes(
  File "/home/luigi/Development/auto_LiRPA/auto_LiRPA/bound_general.py", line 733, in _convert_nodes
    nodesOP[n] = nodesOP[n]._replace(bound_node=op(
  File "/home/luigi/Development/auto_LiRPA/auto_LiRPA/operators/pooling.py", line 29, in __init__
    assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])
IndexError: list index out of range
```
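For context, ONNX stores pooling padding as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so a 2-D pool has four entries and a 1-D pool only two. The hypothetical snippet below evaluates the same boolean expression as the assertion in the traceback on two made-up attribute dicts (`attr_2d`, `attr_1d`); it does not call auto_LiRPA and only shows why index 2 is out of range for the 1-D case.

```python
# Hypothetical attribute dicts mirroring the ONNX MaxPool "pads" layouts:
# 2-D pooling: [x1_begin, x2_begin, x1_end, x2_end]; 1-D pooling: [x1_begin, x1_end].
attr_2d = {"pads": [0, 0, 0, 0]}
attr_1d = {"pads": [0, 0]}


def symmetric_padding_check(attr):
    # Same expression as the assertion at auto_LiRPA/operators/pooling.py, line 29.
    return ("pads" not in attr) or (attr["pads"][0] == attr["pads"][2])


print(symmetric_padding_check(attr_2d))  # True
try:
    symmetric_padding_check(attr_1d)
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range
```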

Could you please help me understand what is going on, and whether there are better ways to tackle this problem? Thanks a lot!

Attachment: minimal_example.py.txt