Verified-Intelligence / auto_LiRPA

auto_LiRPA: An Automatic Linear Relaxation based Perturbation Analysis Library for Neural Networks and General Computational Graphs
https://arxiv.org/pdf/2002.12920
Other
269 stars 67 forks source link

BoundedModule Error on a Convolutional Layer #51

Closed mhmd97z closed 8 months ago

mhmd97z commented 10 months ago

Hi, I got an error when I called BoundedModule() on a DNN with the following definition. I was wondering if you could help me out.

DNN:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorNetwork(nn.Module):
    """Actor network mapping a batch of state matrices to action probabilities.

    The input is indexed as ``inputs[:, row, col]``. Rows 0, 1 and 5
    contribute only their last column, each through its own linear layer.
    Rows 2 and 3 are passed whole through a 1-D convolution (kernel size 4),
    and row 4 is truncated to its first ``action_dim`` columns before its own
    1-D convolution. All branch outputs go through ReLU, are concatenated,
    fed to one hidden linear layer (ReLU) and finally to a linear output
    layer followed by a softmax over the last axis.

    Args:
        state_dim: Indexable whose element ``[1]`` is the length of the last
            input axis (used to size the hidden layer's input).
        action_dim: Number of outputs; also the number of leading columns of
            row 4 fed to ``cConv1d``.
        n_conv: Output channels of each convolution.
        n_fc: Output features of each scalar-branch linear layer.
        n_fc1: Output features of the hidden fully connected layer.
    """

    def __init__(self,state_dim,action_dim,n_conv=128,n_fc=128,n_fc1=128):
        super().__init__()
        self.s_dim=state_dim
        self.a_dim=action_dim
        self.vectorOutDim=n_conv
        self.scalarOutDim=n_fc

        # A kernel of size 4 (stride 1, no padding) shortens a sequence of
        # length L to L - 4 + 1.
        kernel_size=4
        history_conv_len=self.s_dim[1]-kernel_size+1
        action_conv_len=self.a_dim-kernel_size+1
        # Two full-row conv branches + three scalar branches + one conv branch
        # over the first ``a_dim`` columns.
        self.numFcInput=(
            2 * self.vectorOutDim * history_conv_len
            + 3 * self.scalarOutDim
            + self.vectorOutDim * action_conv_len
        )
        self.numFcOutput=n_fc1

        #-------------------define layer-------------------
        # Attribute names and creation order are kept stable: the names are
        # the state_dict keys, and the order fixes the RNG draws of the
        # default initialisation.
        self.tConv1d=nn.Conv1d(1,self.vectorOutDim,kernel_size)
        self.dConv1d=nn.Conv1d(1,self.vectorOutDim,kernel_size)
        self.cConv1d=nn.Conv1d(1,self.vectorOutDim,kernel_size)
        self.bufferFc=nn.Linear(1,self.scalarOutDim)
        self.leftChunkFc=nn.Linear(1,self.scalarOutDim)
        self.bitrateFc=nn.Linear(1,self.scalarOutDim)
        self.fullyConnected=nn.Linear(self.numFcInput,self.numFcOutput)
        self.outputLayer=nn.Linear(self.numFcOutput,self.a_dim)

        #------------------init layer weight--------------------
        # Same layer order as before so seeded runs draw identical weights.
        # outputLayer is not re-initialised here and keeps PyTorch's default.
        for layer in (self.bufferFc,self.leftChunkFc,self.bitrateFc,
                      self.fullyConnected,self.tConv1d,self.dConv1d):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias,0.0)
        # NOTE(review): cConv1d uses xavier_normal_ while every other layer
        # uses xavier_uniform_; kept as-is -- confirm this is intentional.
        nn.init.xavier_normal_(self.cConv1d.weight)
        nn.init.constant_(self.cConv1d.bias,0.0)

    def forward(self,inputs):
        """Return softmax probabilities of shape ``(batch, action_dim)``.

        Args:
            inputs: 3-D tensor indexed ``[batch, row, col]`` with at least six
                rows; the last axis must have length ``state_dim[1]`` for the
                concatenated features to match ``numFcInput``.
        """
        # Scalar branches: last column of a single row, shape (batch, 1).
        bitrateFcOut=F.relu(self.bitrateFc(inputs[:,0:1,-1]),inplace=True)
        bufferFcOut=F.relu(self.bufferFc(inputs[:,1:2,-1]),inplace=True)
        # Convolutional branches: one row kept as a single input channel.
        tConv1dOut=F.relu(self.tConv1d(inputs[:,2:3,:]),inplace=True)
        dConv1dOut=F.relu(self.dConv1d(inputs[:,3:4,:]),inplace=True)
        cConv1dOut=F.relu(self.cConv1d(inputs[:,4:5,:self.a_dim]),inplace=True)
        leftChunkFcOut=F.relu(self.leftChunkFc(inputs[:,5:6,-1]),inplace=True)

        # Flatten (batch, channels, length) -> (batch, channels * length).
        t_flatten=tConv1dOut.view(tConv1dOut.shape[0],-1)
        d_flatten=dConv1dOut.view(dConv1dOut.shape[0],-1)
        # Use the tensor's own batch size (previously borrowed dConv1dOut's).
        c_flatten=cConv1dOut.view(cConv1dOut.shape[0],-1)

        # Concatenation order must match the layout fullyConnected was
        # trained with, so it is left unchanged.
        fullyConnectedInput=torch.cat([bitrateFcOut,bufferFcOut,t_flatten,d_flatten,c_flatten,leftChunkFcOut],1)
        fcOutput=F.relu(self.fullyConnected(fullyConnectedInput),inplace=True)
        out=torch.softmax(self.outputLayer(fcOutput),dim=-1)

        return out

Function Call:

# Dimensions of the state tensor built below: (AGENT_NUM, S_INFO, S_LEN).
S_INFO=6
S_LEN=8
AGENT_NUM=3
ACTION_DIM=6

a_net=ActorNetwork([S_INFO,S_LEN],ACTION_DIM)
npState=torch.randn(AGENT_NUM,S_INFO,S_LEN)
# Smoke-test the forward pass. Call the module itself rather than .forward()
# so that nn.Module.__call__ (and any registered hooks) is used.
action=a_net(npState)

# Checkpoint of the trained actor weights.
actor_checkpoint_path = "/home/mzi/sys-rl-verif/pensieve-pytorch/results/actor.pt"
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host,
# matching device='cpu' below.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
policy_model_state_dict = torch.load(actor_checkpoint_path, map_location='cpu')
a_net.load_state_dict(policy_model_state_dict)

from auto_LiRPA import BoundedModule, BoundedTensor
# The second argument is a dummy input with the same shape as npState.
lirpa_model = BoundedModule(a_net, torch.empty_like(npState), device='cpu')

Error:

Cell In [9], line 2
      1 from auto_LiRPA import BoundedModule, BoundedTensor
----> 2 lirpa_model = BoundedModule(a_net, torch.empty_like(npState), device='cpu')
      4 from auto_LiRPA.perturbations import PerturbationLpNorm
      5 eps = 0.001

File ~/anaconda3/envs/ransim/lib/python3.9/site-packages/auto_LiRPA-0.3.1-py3.9.egg/auto_LiRPA/bound_general.py:103, in BoundedModule.__init__(self, model, global_input, bound_opts, device, verbose, custom_ops)
    100 self.final_shape = model(
    101     *unpack_inputs(global_input, device=self.device)).shape
    102 self.bound_opts.update({'final_shape': self.final_shape})
--> 103 self._convert(model, global_input)
    104 self._mark_perturbed_nodes()
    106 # set the default values here

File ~/anaconda3/envs/ransim/lib/python3.9/site-packages/auto_LiRPA-0.3.1-py3.9.egg/auto_LiRPA/bound_general.py:852, in BoundedModule._convert(self, model, global_input)
    849     global_input = (global_input,)
    850 self.num_global_inputs = len(global_input)
--> 852 nodesOP, nodesIn, nodesOut, template = self._convert_nodes(
    853     model, global_input)
    854 global_input = self._to(global_input, self.device)
    856 while True:

File ~/anaconda3/envs/ransim/lib/python3.9/site-packages/auto_LiRPA-0.3.1-py3.9.egg/auto_LiRPA/bound_general.py:711, in BoundedModule._convert_nodes(self, model, global_input)
    707         nodesOP[n] = nodesOP[n]._replace(bound_node=op(
    708             attr, inputs, nodesOP[n].output_index, self.bound_opts,
    709             False))
    710     else:
--> 711         nodesOP[n] = nodesOP[n]._replace(bound_node=op(
    712             attr, inputs, nodesOP[n].output_index, self.bound_opts))
    714 if unsupported_ops:
    715     logger.error('Unsupported operations:')

File ~/anaconda3/envs/ransim/lib/python3.9/site-packages/auto_LiRPA-0.3.1-py3.9.egg/auto_LiRPA/operators/convolution.py:11, in BoundConv.__init__(self, attr, inputs, output_index, options)
     10 def __init__(self, attr, inputs, output_index, options):
---> 11     assert (attr['pads'][0] == attr['pads'][2])
     12     assert (attr['pads'][1] == attr['pads'][3])
     14     super().__init__(attr, inputs, output_index, options)

IndexError: list index out of range
huanzhang12 commented 10 months ago

@mhmd97z Thank you for reporting this problem to us! We are currently preparing a new release of this library, and we will fix this problem in the upcoming release (planned for September).

shizhouxing commented 8 months ago

@mhmd97z We have released a new version and the issue has been fixed.