Verified-Intelligence / auto_LiRPA

auto_LiRPA: An Automatic Linear Relaxation based Perturbation Analysis Library for Neural Networks and General Computational Graphs
https://arxiv.org/pdf/2002.12920

ConvTranspose CROWN Bounds #61

Open cherrywoods opened 9 months ago

cherrywoods commented 9 months ago

Describe the bug

I was delighted to see that auto_LiRPA can bound ConvTranspose layers out of the box, but, unfortunately, CROWN in batch mode doesn't seem to work.

To Reproduce

Code to reproduce the issue with the attached network (zipped): mnist_conv_generator.zip

>>> import torch
>>> from auto_LiRPA import PerturbationLpNorm, BoundedModule, BoundedTensor
/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
No CUDA runtime is found, using CUDA_HOME='/usr'
>>> net = torch.load("mnist_conv_generator.pyt")
>>> net = BoundedModule(net, torch.zeros(1, 4, 1, 1))
/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/torch/nn/functional.py:2403: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if size_prods == 1:
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '6'. This changes graph semantics.
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '12'. This changes graph semantics.
>>> ptb = PerturbationLpNorm(x_L=torch.zeros(1, 4, 1, 1), x_U=torch.ones(1, 4, 1, 1))
>>> tensor = BoundedTensor(torch.zeros(10, 4, 1, 1), ptb)
>>> net.compute_bounds(x=(tensor,), method="ibp")  # works fine, output omitted
>>> net.compute_bounds(x=(tensor,), method="crown")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 1206, in compute_bounds
    return self._compute_bounds_main(C=C,
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 1303, in _compute_bounds_main
    self.check_prior_bounds(final)
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  [Previous line repeated 2 more times]
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 804, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 915, in compute_intermediate_bounds
    self.restore_sparse_bounds(
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/backward_bound.py", line 575, in restore_sparse_bounds
    lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_lower
RuntimeError: shape mismatch: value tensor of shape [10, 703] cannot be broadcast to indexing result of shape [1, 703]

System configuration:

cherrywoods commented 9 months ago

CROWN with a single input (e.g. torch.zeros(1, 4, 1, 1) instead of torch.zeros(10, 4, 1, 1)) works fine.
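The RuntimeError in the original report is raised by an indexed in-place assignment (`lower[:, unstable_idx[0], ...] = new_lower`). For that kind of assignment, PyTorch broadcasts the value to the shape of the indexed slice, which only works in one direction: a size-1 dimension can be stretched, but a batch of 10 cannot be squeezed into a size-1 slot. The sketch below is a simplified, plain-Python model of that one-way rule on shape tuples; `can_broadcast_to` is a hypothetical helper, not part of PyTorch or auto_LiRPA. It is consistent with batch size 1 working while batch size 10 fails.

```python
def can_broadcast_to(value_shape, target_shape):
    """Return True if a tensor of value_shape can be broadcast to target_shape.

    Simplified one-way rule: shapes are aligned from the right, and each value
    dimension must equal the target dimension or be 1. Extra leading value
    dimensions (with no counterpart in the target) are accepted only if they are 1.
    """
    extra = len(value_shape) - len(target_shape)
    if extra > 0:
        if any(d != 1 for d in value_shape[:extra]):
            return False
        value_shape = value_shape[extra:]
    for v, t in zip(reversed(value_shape), reversed(target_shape)):
        if v != t and v != 1:
            return False
    return True


# Shapes taken from the RuntimeError in the report.
print(can_broadcast_to((10, 703), (1, 703)))  # False: batch of 10 into a batch-1 slice
print(can_broadcast_to((1, 703), (1, 703)))   # True: the single-input case
```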

shizhouxing commented 9 months ago

Hi @cherrywoods, could you please share the code for the model definition?

cherrywoods commented 9 months ago

Sure, sorry for not including it right away:

from torch import nn

generator = nn.Sequential(
    nn.ConvTranspose2d(4, 49, kernel_size=4, stride=1, bias=False),  # 49 x 4 x 4
    nn.BatchNorm2d(49, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(49, 12, kernel_size=4, stride=4, bias=False),  # 12 x 16 x 16
    nn.BatchNorm2d(12, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(12, 1, kernel_size=13, stride=1, bias=False),  # 1 x 28 x 28
    nn.Sigmoid(),
)
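The shape comments in this model definition can be checked with the output-size formula from the PyTorch `ConvTranspose2d` documentation, H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. The sketch below applies it layer by layer, with padding, dilation and output_padding left at their defaults (0, 1 and 0), as in the model above, and also counts the neurons per layer. The helper name `conv_transpose_out` is illustrative.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0, dilation=1, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d along one dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


# (out_channels, kernel_size, stride) for the three ConvTranspose2d layers above.
layers = [(49, 4, 1), (12, 4, 4), (1, 13, 1)]

size = 1  # the model is traced with a 4 x 1 x 1 input
shapes = []
for channels, kernel, stride in layers:
    size = conv_transpose_out(size, kernel, stride)
    shapes.append((channels, size, size))

for c, h, w in shapes:
    print(f"{c} x {h} x {w} -> {c * h * w} neurons")
# 49 x 4 x 4 -> 784 neurons
# 12 x 16 x 16 -> 3072 neurons
# 1 x 28 x 28 -> 784 neurons
```

Both the first ConvTranspose output and the network output have 784 neurons, so the number 784 in the later error messages does not by itself identify which layer is involved.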
shizhouxing commented 9 months ago

Hi @cherrywoods, the issue is that ptb also needs a batch size of 10: its x_L and x_U have a batch dimension of 1, while the BoundedTensor has a batch dimension of 10.
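A simple guard against this kind of mismatch is to check, before calling compute_bounds, that the center tensor and the tensors passed as x_L and x_U all have the same shape, including the batch dimension. Below is a minimal sketch on plain shape tuples; `check_ptb_shapes` is a hypothetical helper, not an auto_LiRPA function, and requiring exact equality is a deliberately conservative assumption (auto_LiRPA may tolerate some broadcastable shapes). In a real script you would pass `tuple(t.shape)` for each of the three tensors.

```python
def check_ptb_shapes(center_shape, x_l_shape, x_u_shape):
    """Raise ValueError unless the center, x_L and x_U shapes are identical."""
    shapes = {"center": tuple(center_shape), "x_L": tuple(x_l_shape), "x_U": tuple(x_u_shape)}
    if len(set(shapes.values())) != 1:
        raise ValueError(f"inconsistent perturbation shapes: {shapes}")


# Shapes from the first report: batch of 10 for the center, 1 for x_L and x_U.
try:
    check_ptb_shapes((10, 4, 1, 1), (1, 4, 1, 1), (1, 4, 1, 1))
except ValueError as err:
    print(err)

# Fixed version: every tensor has batch size 10, so no error is raised.
check_ptb_shapes((10, 4, 1, 1), (10, 4, 1, 1), (10, 4, 1, 1))
```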

cherrywoods commented 9 months ago

Hi @shizhouxing, the incorrect batch dimension was indeed a problem in the code I posted; however, a very similar error persists even with fixed batch dimensions:

import torch
from auto_LiRPA import PerturbationLpNorm, BoundedModule, BoundedTensor
net = torch.load("mnist_conv_generator.pyt")
net = BoundedModule(net, torch.zeros(1, 4, 1, 1))
ptb = PerturbationLpNorm(x_L=torch.zeros(10, 4, 1, 1), x_U=torch.ones(10, 4, 1, 1))
tensor = BoundedTensor(torch.zeros(10, 4, 1, 1), ptb)
net.compute_bounds(x=(tensor,), method="ibp")  # works fine, output omitted
net.compute_bounds(x=(tensor,), method="crown")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 1339, in compute_bounds
    self.check_prior_bounds(final)
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 883, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 883, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 883, in check_prior_bounds
    self.check_prior_bounds(n)
  [Previous line repeated 2 more times]
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 885, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 983, in compute_intermediate_bounds
    node.lower, node.upper = self.backward_general(
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/backward_bound.py", line 212, in backward_general
    lb, ub = concretize(
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/backward_bound.py", line 532, in concretize
    lb = lb + root[i].perturbation.concretize(
RuntimeError: The size of tensor a (4) must match the size of tensor b (784) at non-singleton dimension 3
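This second error comes from an out-of-place addition (`lb = lb + root[i].perturbation.concretize(...)`), where PyTorch broadcasts both operands against each other: aligned from the right, each pair of dimensions must be equal or one of them must be 1. A size of 4 against 784 in the same position violates that rule, which points to a mismatch in a non-batch dimension. The sketch below implements the two-way rule on shape tuples; the shapes in the failing example are hypothetical, chosen only to reproduce a 4-vs-784 clash at dimension 3. For real tensors, PyTorch's own `torch.broadcast_shapes` performs this check.

```python
def broadcast_shapes(a, b):
    """Return the broadcast shape of a and b, or raise ValueError (two-way rule)."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same number of dims.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"size {x} vs {y} at dimension {dim}")
    return tuple(result)


print(broadcast_shapes((10, 1, 784), (1, 5, 784)))  # (10, 5, 784)

# Hypothetical shapes that clash the same way as the traceback (4 vs 784 at dim 3).
try:
    broadcast_shapes((1, 10, 1, 4), (1, 10, 1, 784))
except ValueError as err:
    print(err)  # size 4 vs 784 at dimension 3
```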
shizhouxing commented 9 months ago

Hi @cherrywoods,

You'll need to modify both x_L and x_U; in your code, x_U still has a batch size of 1:

ptb = PerturbationLpNorm(x_L=torch.zeros(10, 4, 1, 1), x_U=torch.ones(1, 4, 1, 1))
cherrywoods commented 9 months ago

Hi @shizhouxing, this was only a typo. I updated the code above. The error remains the same.

shizhouxing commented 9 months ago

Hi @cherrywoods, I tried your code and it worked fine on my side.

I see your output contains auto_LiRPA-0.3.1. Are you using the latest version of auto_LiRPA? The latest version should have a version number of 0.4.
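Since an outdated install turned out to be the culprit, a repro script can fail early when the installed version is too old. The script later in this thread already prints `auto_LiRPA.__version__`; the sketch below compares such a version string against a minimum using integer tuples. The helpers `version_tuple` and `require_version` are hypothetical and only handle simple dotted numeric versions such as 0.3.1 or 0.4; `packaging.version` is the more robust choice when it is available.

```python
def version_tuple(version, length=3):
    """Turn a dotted numeric version like '0.4' or '0.4.0' into (0, 4, 0)."""
    parts = [int(part) for part in version.split(".")]
    # Pad with zeros so that '0.4' and '0.4.0' compare as equal.
    return tuple(parts + [0] * (length - len(parts)))


def require_version(installed, minimum):
    """Raise RuntimeError if the installed version is older than the minimum."""
    if version_tuple(installed) < version_tuple(minimum):
        raise RuntimeError(f"auto_LiRPA {installed} is older than the required {minimum}")


# In a repro script this would be: require_version(auto_LiRPA.__version__, "0.4.0")
require_version("0.4.0", "0.4")  # passes silently
try:
    require_version("0.3.1", "0.4.0")  # the version seen in the second traceback
except RuntimeError as err:
    print(err)  # auto_LiRPA 0.3.1 is older than the required 0.4.0
```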

cherrywoods commented 9 months ago

That indeed seemed to be the issue: I somehow messed up pulling the latest release from GitHub. Thanks for your patience, and sorry for the inconvenience. I'm happy that I can now use ConvTranspose layers :)

cherrywoods commented 9 months ago

I'm reopening this because I keep getting errors in the code I'm actually using, which uses bounds other than 0.0 and 1.0. I debugged this for the past hour and couldn't find anything resembling the errors we discussed above. To be on the safe side this time, I made a Docker container that reproduces the issue: conv_transpose_issue.zip

The container creates a conda environment, downloads and installs the latest auto_LiRPA commit and then runs the following script:

import torch
from torch import nn
import auto_LiRPA
from auto_LiRPA import PerturbationLpNorm, BoundedModule, BoundedTensor

print(auto_LiRPA.__version__)

torch.manual_seed(0)
net = nn.Sequential(
    nn.ConvTranspose2d(4, 49, kernel_size=4, stride=1, bias=False),  # 49 x 4 x 4
    nn.BatchNorm2d(49, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(49, 12, kernel_size=4, stride=4, bias=False),  # 12 x 16 x 16
    nn.BatchNorm2d(12, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(12, 1, kernel_size=13, stride=1, bias=False),  # 1 x 28 x 28
    nn.Sigmoid(),
)
net = BoundedModule(net, torch.empty(1, 4, 1, 1))

lb = torch.zeros(1, 4, 1, 1)
ub = torch.ones(1, 4, 1, 1)
ptb = PerturbationLpNorm(x_L=lb, x_U=ub)
tensor = BoundedTensor(lb, ptb)
print(lb.shape, ub.shape, tensor.shape)
print(lb, ub, tensor)
bounds = net.compute_bounds(x=(tensor,), method="crown")  # works fine
print(bounds)

lb = lb.clone() - 1.0
ptb = PerturbationLpNorm(x_L=lb, x_U=ub)
tensor = BoundedTensor(lb, ptb)
print(lb.shape, ub.shape, tensor.shape)
print(lb, ub, tensor)
bounds = net.compute_bounds(x=(tensor,), method="crown")  # fails
print(bounds)

When I run this using:

docker build . -t auto_lirpa
docker run -t auto_lirpa

I get this output:

/opt/conda/envs/auto_LiRPA/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
0.4.0
/opt/conda/envs/auto_LiRPA/lib/python3.10/site-packages/torch/nn/functional.py:2403: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if size_prods == 1:
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '6'. This changes graph semantics.
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '12'. This changes graph semantics.
torch.Size([1, 4, 1, 1]) torch.Size([1, 4, 1, 1]) (1, 4, 1, 1)
tensor([[[[0.]],

         [[0.]],

         [[0.]],

         [[0.]]]]) tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]) <BoundedTensor: BoundedTensor([[[[0.]],

                [[0.]],

                [[0.]],

                [[0.]]]]), PerturbationLpNorm(norm=inf, eps=0, x_L=tensor([[[[0.]],

         [[0.]],

         [[0.]],

         [[0.]]]]), x_U=tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]))>
(tensor([[[[0.4989, 0.4986, 0.4980, 0.4974, 0.4972, 0.4965, 0.4962, 0.4952,
           0.4960, 0.4955, 0.4947, 0.4937, 0.4937, 0.4942, 0.4927, 0.4931,
           0.4949, 0.4950, 0.4949, 0.4951, 0.4974, 0.4967, 0.4970, 0.4978,
           0.4982, 0.4984, 0.4992, 0.4991],
          [0.4988, 0.4978, 0.4970, 0.4960, 0.4946, 0.4945, 0.4925, 0.4929,
           0.4919, 0.4920, 0.4908, 0.4898, 0.4899, 0.4897, 0.4895, 0.4895,
           0.4908, 0.4920, 0.4924, 0.4935, 0.4945, 0.4948, 0.4960, 0.4963,
           0.4970, 0.4968, 0.4985, 0.4991],
          [0.4979, 0.4969, 0.4955, 0.4937, 0.4921, 0.4922, 0.4906, 0.4904,
           0.4886, 0.4863, 0.4868, 0.4859, 0.4847, 0.4852, 0.4834, 0.4849,
           0.4853, 0.4862, 0.4894, 0.4895, 0.4915, 0.4921, 0.4932, 0.4933,
           0.4948, 0.4965, 0.4969, 0.4981],
          [0.4980, 0.4958, 0.4950, 0.4936, 0.4930, 0.4902, 0.4890, 0.4883,
           0.4864, 0.4847, 0.4837, 0.4822, 0.4816, 0.4810, 0.4812, 0.4808,
           0.4838, 0.4839, 0.4856, 0.4872, 0.4884, 0.4903, 0.4919, 0.4926,
           0.4932, 0.4949, 0.4968, 0.4986],
          [0.4972, 0.4944, 0.4926, 0.4917, 0.4899, 0.4879, 0.4849, 0.4841,
           0.4831, 0.4800, 0.4791, 0.4771, 0.4773, 0.4755, 0.4749, 0.4753,
           0.4782, 0.4796, 0.4813, 0.4834, 0.4866, 0.4876, 0.4888, 0.4905,
           0.4929, 0.4935, 0.4953, 0.4978],
          [0.4966, 0.4946, 0.4924, 0.4890, 0.4878, 0.4853, 0.4827, 0.4808,
           0.4784, 0.4778, 0.4756, 0.4726, 0.4713, 0.4724, 0.4703, 0.4710,
           0.4733, 0.4761, 0.4784, 0.4807, 0.4831, 0.4860, 0.4872, 0.4885,
           0.4904, 0.4931, 0.4945, 0.4974],
          [0.4967, 0.4941, 0.4914, 0.4891, 0.4847, 0.4837, 0.4812, 0.4793,
           0.4758, 0.4712, 0.4710, 0.4680, 0.4634, 0.4664, 0.4658, 0.4656,
           0.4682, 0.4713, 0.4742, 0.4776, 0.4787, 0.4813, 0.4848, 0.4868,
           0.4888, 0.4919, 0.4946, 0.4970],
          [0.4957, 0.4937, 0.4898, 0.4873, 0.4840, 0.4801, 0.4780, 0.4752,
           0.4733, 0.4695, 0.4679, 0.4649, 0.4632, 0.4615, 0.4618, 0.4622,
           0.4654, 0.4679, 0.4705, 0.4751, 0.4779, 0.4804, 0.4826, 0.4848,
           0.4868, 0.4911, 0.4937, 0.4970],
          [0.4960, 0.4917, 0.4870, 0.4846, 0.4805, 0.4777, 0.4753, 0.4731,
           0.4690, 0.4654, 0.4628, 0.4602, 0.4559, 0.4551, 0.4559, 0.4562,
           0.4601, 0.4631, 0.4674, 0.4710, 0.4756, 0.4768, 0.4803, 0.4826,
           0.4871, 0.4893, 0.4928, 0.4963],
          [0.4952, 0.4918, 0.4874, 0.4831, 0.4779, 0.4769, 0.4736, 0.4697,
           0.4658, 0.4622, 0.4592, 0.4563, 0.4500, 0.4524, 0.4527, 0.4506,
           0.4557, 0.4603, 0.4641, 0.4678, 0.4725, 0.4752, 0.4782, 0.4809,
           0.4847, 0.4887, 0.4923, 0.4958],
          [0.4953, 0.4901, 0.4872, 0.4829, 0.4787, 0.4749, 0.4712, 0.4677,
           0.4632, 0.4582, 0.4555, 0.4524, 0.4484, 0.4472, 0.4479, 0.4479,
           0.4521, 0.4565, 0.4608, 0.4645, 0.4686, 0.4728, 0.4759, 0.4796,
           0.4830, 0.4863, 0.4917, 0.4955],
          [0.4949, 0.4898, 0.4857, 0.4820, 0.4769, 0.4712, 0.4692, 0.4641,
           0.4607, 0.4569, 0.4527, 0.4478, 0.4450, 0.4428, 0.4435, 0.4430,
           0.4489, 0.4526, 0.4569, 0.4615, 0.4675, 0.4697, 0.4744, 0.4778,
           0.4819, 0.4843, 0.4912, 0.4953],
          [0.4943, 0.4882, 0.4841, 0.4808, 0.4739, 0.4696, 0.4659, 0.4614,
           0.4552, 0.4526, 0.4486, 0.4449, 0.4390, 0.4368, 0.4373, 0.4373,
           0.4441, 0.4469, 0.4523, 0.4572, 0.4636, 0.4666, 0.4704, 0.4750,
           0.4804, 0.4850, 0.4884, 0.4928],
          [0.4936, 0.4884, 0.4845, 0.4803, 0.4742, 0.4681, 0.4656, 0.4614,
           0.4557, 0.4509, 0.4485, 0.4396, 0.4384, 0.4378, 0.4364, 0.4367,
           0.4424, 0.4479, 0.4525, 0.4561, 0.4633, 0.4677, 0.4707, 0.4752,
           0.4794, 0.4846, 0.4898, 0.4944],
          [0.4935, 0.4893, 0.4848, 0.4804, 0.4746, 0.4701, 0.4669, 0.4623,
           0.4571, 0.4509, 0.4480, 0.4437, 0.4388, 0.4381, 0.4387, 0.4384,
           0.4429, 0.4482, 0.4523, 0.4569, 0.4616, 0.4673, 0.4707, 0.4750,
           0.4791, 0.4847, 0.4894, 0.4941],
          [0.4942, 0.4888, 0.4855, 0.4810, 0.4756, 0.4702, 0.4668, 0.4617,
           0.4572, 0.4532, 0.4489, 0.4444, 0.4398, 0.4388, 0.4389, 0.4361,
           0.4439, 0.4489, 0.4535, 0.4575, 0.4627, 0.4667, 0.4706, 0.4753,
           0.4802, 0.4845, 0.4899, 0.4947],
          [0.4953, 0.4903, 0.4865, 0.4830, 0.4786, 0.4732, 0.4701, 0.4660,
           0.4613, 0.4570, 0.4531, 0.4504, 0.4453, 0.4426, 0.4439, 0.4434,
           0.4478, 0.4525, 0.4580, 0.4615, 0.4663, 0.4693, 0.4735, 0.4774,
           0.4829, 0.4848, 0.4893, 0.4947],
          [0.4947, 0.4915, 0.4872, 0.4836, 0.4781, 0.4752, 0.4727, 0.4680,
           0.4640, 0.4603, 0.4577, 0.4527, 0.4488, 0.4470, 0.4481, 0.4476,
           0.4520, 0.4570, 0.4598, 0.4624, 0.4687, 0.4718, 0.4731, 0.4789,
           0.4813, 0.4874, 0.4905, 0.4952],
          [0.4964, 0.4926, 0.4888, 0.4856, 0.4815, 0.4786, 0.4748, 0.4711,
           0.4684, 0.4646, 0.4617, 0.4579, 0.4537, 0.4531, 0.4533, 0.4528,
           0.4566, 0.4604, 0.4632, 0.4673, 0.4712, 0.4747, 0.4776, 0.4811,
           0.4844, 0.4883, 0.4918, 0.4959],
          [0.4962, 0.4923, 0.4897, 0.4862, 0.4825, 0.4803, 0.4776, 0.4715,
           0.4707, 0.4683, 0.4652, 0.4618, 0.4579, 0.4565, 0.4573, 0.4566,
           0.4620, 0.4640, 0.4677, 0.4699, 0.4749, 0.4756, 0.4799, 0.4827,
           0.4857, 0.4893, 0.4933, 0.4962],
          [0.4968, 0.4936, 0.4910, 0.4885, 0.4847, 0.4817, 0.4796, 0.4777,
           0.4731, 0.4713, 0.4695, 0.4670, 0.4624, 0.4620, 0.4623, 0.4621,
           0.4650, 0.4669, 0.4713, 0.4743, 0.4754, 0.4798, 0.4820, 0.4828,
           0.4884, 0.4905, 0.4932, 0.4962],
          [0.4958, 0.4934, 0.4917, 0.4896, 0.4863, 0.4839, 0.4809, 0.4795,
           0.4761, 0.4750, 0.4724, 0.4700, 0.4679, 0.4656, 0.4665, 0.4649,
           0.4689, 0.4707, 0.4747, 0.4770, 0.4801, 0.4806, 0.4831, 0.4871,
           0.4893, 0.4920, 0.4941, 0.4966],
          [0.4977, 0.4953, 0.4931, 0.4912, 0.4893, 0.4868, 0.4846, 0.4821,
           0.4806, 0.4787, 0.4760, 0.4750, 0.4719, 0.4721, 0.4710, 0.4709,
           0.4736, 0.4759, 0.4771, 0.4793, 0.4815, 0.4827, 0.4859, 0.4884,
           0.4891, 0.4919, 0.4948, 0.4975],
          [0.4976, 0.4958, 0.4939, 0.4926, 0.4904, 0.4893, 0.4868, 0.4849,
           0.4830, 0.4822, 0.4804, 0.4789, 0.4762, 0.4753, 0.4757, 0.4758,
           0.4772, 0.4802, 0.4815, 0.4820, 0.4846, 0.4864, 0.4883, 0.4892,
           0.4917, 0.4933, 0.4946, 0.4975],
          [0.4978, 0.4969, 0.4950, 0.4953, 0.4937, 0.4911, 0.4896, 0.4891,
           0.4872, 0.4861, 0.4852, 0.4842, 0.4821, 0.4798, 0.4817, 0.4815,
           0.4834, 0.4844, 0.4856, 0.4867, 0.4868, 0.4890, 0.4906, 0.4925,
           0.4933, 0.4941, 0.4954, 0.4980],
          [0.4985, 0.4967, 0.4963, 0.4960, 0.4948, 0.4932, 0.4925, 0.4916,
           0.4899, 0.4890, 0.4877, 0.4865, 0.4855, 0.4851, 0.4836, 0.4860,
           0.4866, 0.4872, 0.4881, 0.4892, 0.4910, 0.4901, 0.4919, 0.4943,
           0.4953, 0.4961, 0.4972, 0.4987],
          [0.4988, 0.4982, 0.4975, 0.4968, 0.4963, 0.4956, 0.4941, 0.4936,
           0.4926, 0.4925, 0.4916, 0.4900, 0.4899, 0.4899, 0.4894, 0.4887,
           0.4900, 0.4909, 0.4917, 0.4924, 0.4931, 0.4939, 0.4946, 0.4958,
           0.4970, 0.4976, 0.4977, 0.4990],
          [0.4992, 0.4991, 0.4987, 0.4983, 0.4977, 0.4972, 0.4969, 0.4960,
           0.4961, 0.4953, 0.4954, 0.4944, 0.4944, 0.4943, 0.4938, 0.4942,
           0.4943, 0.4951, 0.4955, 0.4958, 0.4961, 0.4968, 0.4970, 0.4980,
           0.4977, 0.4986, 0.4986, 0.4992]]]], grad_fn=<ViewBackward0>), tensor([[[[0.5008, 0.5009, 0.5015, 0.5019, 0.5022, 0.5023, 0.5029, 0.5033,
           0.5039, 0.5049, 0.5048, 0.5050, 0.5060, 0.5064, 0.5057, 0.5054,
           0.5054, 0.5060, 0.5048, 0.5048, 0.5041, 0.5039, 0.5032, 0.5027,
           0.5020, 0.5027, 0.5015, 0.5011],
          [0.5012, 0.5017, 0.5018, 0.5030, 0.5038, 0.5050, 0.5054, 0.5062,
           0.5070, 0.5081, 0.5082, 0.5088, 0.5104, 0.5102, 0.5110, 0.5100,
           0.5097, 0.5092, 0.5081, 0.5083, 0.5064, 0.5058, 0.5055, 0.5044,
           0.5032, 0.5038, 0.5023, 0.5016],
          [0.5014, 0.5027, 0.5038, 0.5048, 0.5065, 0.5078, 0.5094, 0.5097,
           0.5095, 0.5121, 0.5125, 0.5132, 0.5140, 0.5152, 0.5156, 0.5144,
           0.5125, 0.5133, 0.5128, 0.5106, 0.5088, 0.5090, 0.5088, 0.5059,
           0.5050, 0.5047, 0.5031, 0.5016],
          [0.5017, 0.5035, 0.5040, 0.5057, 0.5077, 0.5097, 0.5103, 0.5125,
           0.5140, 0.5142, 0.5167, 0.5178, 0.5195, 0.5193, 0.5199, 0.5190,
           0.5177, 0.5162, 0.5152, 0.5145, 0.5117, 0.5104, 0.5099, 0.5077,
           0.5078, 0.5060, 0.5047, 0.5025],
          [0.5022, 0.5048, 0.5058, 0.5077, 0.5104, 0.5123, 0.5137, 0.5166,
           0.5190, 0.5201, 0.5216, 0.5229, 0.5264, 0.5275, 0.5257, 0.5255,
           0.5246, 0.5216, 0.5203, 0.5198, 0.5159, 0.5136, 0.5132, 0.5105,
           0.5086, 0.5076, 0.5061, 0.5031],
          [0.5031, 0.5056, 0.5074, 0.5090, 0.5117, 0.5158, 0.5184, 0.5193,
           0.5217, 0.5267, 0.5264, 0.5268, 0.5311, 0.5323, 0.5299, 0.5301,
           0.5289, 0.5286, 0.5241, 0.5220, 0.5192, 0.5171, 0.5147, 0.5123,
           0.5096, 0.5081, 0.5060, 0.5034],
          [0.5037, 0.5062, 0.5086, 0.5097, 0.5139, 0.5167, 0.5196, 0.5223,
           0.5252, 0.5300, 0.5293, 0.5315, 0.5346, 0.5349, 0.5373, 0.5342,
           0.5324, 0.5310, 0.5278, 0.5252, 0.5217, 0.5184, 0.5178, 0.5136,
           0.5110, 0.5099, 0.5070, 0.5042],
          [0.5035, 0.5070, 0.5091, 0.5117, 0.5166, 0.5184, 0.5234, 0.5264,
           0.5279, 0.5297, 0.5338, 0.5364, 0.5399, 0.5396, 0.5395, 0.5409,
           0.5361, 0.5323, 0.5319, 0.5282, 0.5258, 0.5209, 0.5202, 0.5160,
           0.5129, 0.5113, 0.5084, 0.5040],
          [0.5048, 0.5074, 0.5100, 0.5134, 0.5172, 0.5210, 0.5242, 0.5290,
           0.5319, 0.5347, 0.5379, 0.5411, 0.5484, 0.5445, 0.5443, 0.5442,
           0.5407, 0.5373, 0.5344, 0.5312, 0.5274, 0.5259, 0.5216, 0.5177,
           0.5144, 0.5113, 0.5084, 0.5048],
          [0.5051, 0.5083, 0.5110, 0.5138, 0.5184, 0.5240, 0.5285, 0.5298,
           0.5342, 0.5385, 0.5417, 0.5433, 0.5484, 0.5491, 0.5492, 0.5486,
           0.5444, 0.5427, 0.5379, 0.5337, 0.5312, 0.5264, 0.5238, 0.5191,
           0.5152, 0.5119, 0.5092, 0.5052],
          [0.5049, 0.5102, 0.5127, 0.5153, 0.5213, 0.5271, 0.5303, 0.5329,
           0.5376, 0.5417, 0.5439, 0.5466, 0.5519, 0.5542, 0.5535, 0.5526,
           0.5479, 0.5453, 0.5408, 0.5373, 0.5316, 0.5288, 0.5266, 0.5206,
           0.5163, 0.5134, 0.5096, 0.5057],
          [0.5054, 0.5102, 0.5140, 0.5174, 0.5236, 0.5275, 0.5317, 0.5359,
           0.5408, 0.5448, 0.5471, 0.5524, 0.5584, 0.5583, 0.5577, 0.5592,
           0.5541, 0.5509, 0.5453, 0.5430, 0.5363, 0.5312, 0.5272, 0.5234,
           0.5203, 0.5151, 0.5116, 0.5063],
          [0.5067, 0.5103, 0.5155, 0.5196, 0.5271, 0.5296, 0.5346, 0.5396,
           0.5461, 0.5484, 0.5517, 0.5568, 0.5640, 0.5631, 0.5642, 0.5643,
           0.5596, 0.5536, 0.5501, 0.5447, 0.5400, 0.5331, 0.5298, 0.5257,
           0.5204, 0.5171, 0.5121, 0.5067],
          [0.5055, 0.5113, 0.5148, 0.5190, 0.5242, 0.5304, 0.5354, 0.5395,
           0.5440, 0.5492, 0.5527, 0.5561, 0.5630, 0.5643, 0.5652, 0.5627,
           0.5584, 0.5548, 0.5523, 0.5439, 0.5392, 0.5345, 0.5302, 0.5252,
           0.5214, 0.5163, 0.5128, 0.5067],
          [0.5056, 0.5112, 0.5151, 0.5184, 0.5250, 0.5296, 0.5356, 0.5383,
           0.5453, 0.5491, 0.5522, 0.5553, 0.5627, 0.5649, 0.5649, 0.5638,
           0.5582, 0.5543, 0.5506, 0.5447, 0.5394, 0.5363, 0.5302, 0.5254,
           0.5195, 0.5179, 0.5122, 0.5066],
          [0.5060, 0.5101, 0.5151, 0.5191, 0.5254, 0.5296, 0.5344, 0.5388,
           0.5439, 0.5480, 0.5514, 0.5564, 0.5628, 0.5630, 0.5642, 0.5639,
           0.5576, 0.5549, 0.5503, 0.5453, 0.5392, 0.5346, 0.5319, 0.5258,
           0.5210, 0.5188, 0.5124, 0.5068],
          [0.5052, 0.5097, 0.5139, 0.5177, 0.5237, 0.5268, 0.5314, 0.5356,
           0.5405, 0.5440, 0.5483, 0.5514, 0.5571, 0.5567, 0.5577, 0.5580,
           0.5549, 0.5527, 0.5456, 0.5406, 0.5379, 0.5310, 0.5297, 0.5229,
           0.5186, 0.5140, 0.5110, 0.5064],
          [0.5048, 0.5090, 0.5137, 0.5158, 0.5208, 0.5258, 0.5305, 0.5320,
           0.5367, 0.5403, 0.5451, 0.5461, 0.5528, 0.5526, 0.5539, 0.5523,
           0.5490, 0.5463, 0.5413, 0.5360, 0.5328, 0.5284, 0.5261, 0.5204,
           0.5178, 0.5127, 0.5097, 0.5052],
          [0.5051, 0.5088, 0.5117, 0.5147, 0.5196, 0.5231, 0.5262, 0.5285,
           0.5342, 0.5365, 0.5397, 0.5421, 0.5477, 0.5485, 0.5488, 0.5478,
           0.5439, 0.5416, 0.5382, 0.5338, 0.5292, 0.5264, 0.5230, 0.5191,
           0.5145, 0.5125, 0.5099, 0.5059],
          [0.5042, 0.5076, 0.5126, 0.5132, 0.5172, 0.5212, 0.5247, 0.5264,
           0.5315, 0.5334, 0.5369, 0.5389, 0.5433, 0.5455, 0.5435, 0.5436,
           0.5411, 0.5383, 0.5338, 0.5313, 0.5279, 0.5237, 0.5209, 0.5200,
           0.5139, 0.5118, 0.5086, 0.5043],
          [0.5037, 0.5067, 0.5093, 0.5124, 0.5148, 0.5172, 0.5201, 0.5238,
           0.5262, 0.5288, 0.5306, 0.5330, 0.5368, 0.5372, 0.5377, 0.5376,
           0.5347, 0.5320, 0.5293, 0.5260, 0.5234, 0.5203, 0.5174, 0.5148,
           0.5120, 0.5094, 0.5076, 0.5041],
          [0.5027, 0.5052, 0.5080, 0.5102, 0.5125, 0.5161, 0.5179, 0.5199,
           0.5224, 0.5248, 0.5270, 0.5292, 0.5329, 0.5347, 0.5345, 0.5344,
           0.5304, 0.5298, 0.5261, 0.5229, 0.5209, 0.5181, 0.5161, 0.5136,
           0.5106, 0.5088, 0.5087, 0.5038],
          [0.5031, 0.5046, 0.5079, 0.5091, 0.5125, 0.5146, 0.5158, 0.5173,
           0.5218, 0.5223, 0.5230, 0.5250, 0.5288, 0.5297, 0.5306, 0.5282,
           0.5266, 0.5265, 0.5221, 0.5196, 0.5180, 0.5169, 0.5140, 0.5119,
           0.5096, 0.5074, 0.5060, 0.5030],
          [0.5029, 0.5038, 0.5061, 0.5073, 0.5121, 0.5134, 0.5130, 0.5144,
           0.5175, 0.5185, 0.5191, 0.5212, 0.5246, 0.5248, 0.5243, 0.5244,
           0.5223, 0.5217, 0.5184, 0.5161, 0.5147, 0.5134, 0.5111, 0.5100,
           0.5084, 0.5060, 0.5060, 0.5032],
          [0.5019, 0.5037, 0.5051, 0.5063, 0.5087, 0.5087, 0.5109, 0.5114,
           0.5141, 0.5142, 0.5156, 0.5170, 0.5196, 0.5188, 0.5198, 0.5195,
           0.5183, 0.5175, 0.5159, 0.5144, 0.5119, 0.5104, 0.5086, 0.5080,
           0.5072, 0.5054, 0.5046, 0.5020],
          [0.5018, 0.5027, 0.5041, 0.5051, 0.5063, 0.5084, 0.5088, 0.5093,
           0.5109, 0.5128, 0.5127, 0.5130, 0.5145, 0.5147, 0.5154, 0.5151,
           0.5136, 0.5127, 0.5113, 0.5107, 0.5089, 0.5084, 0.5068, 0.5058,
           0.5056, 0.5044, 0.5033, 0.5021],
          [0.5008, 0.5020, 0.5032, 0.5035, 0.5045, 0.5055, 0.5059, 0.5066,
           0.5074, 0.5080, 0.5085, 0.5091, 0.5111, 0.5107, 0.5111, 0.5108,
           0.5108, 0.5089, 0.5081, 0.5079, 0.5073, 0.5067, 0.5051, 0.5046,
           0.5040, 0.5039, 0.5025, 0.5018],
          [0.5005, 0.5009, 0.5013, 0.5027, 0.5026, 0.5030, 0.5038, 0.5034,
           0.5044, 0.5042, 0.5045, 0.5048, 0.5058, 0.5061, 0.5065, 0.5062,
           0.5062, 0.5051, 0.5048, 0.5043, 0.5036, 0.5037, 0.5028, 0.5024,
           0.5022, 0.5019, 0.5013, 0.5014]]]], grad_fn=<ViewBackward0>))
torch.Size([1, 4, 1, 1]) torch.Size([1, 4, 1, 1]) (1, 4, 1, 1)
tensor([[[[-1.]],

         [[-1.]],

         [[-1.]],

         [[-1.]]]]) tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]) <BoundedTensor: BoundedTensor([[[[-1.]],

                [[-1.]],

                [[-1.]],

                [[-1.]]]]), PerturbationLpNorm(norm=inf, eps=0, x_L=tensor([[[[-1.]],

         [[-1.]],

         [[-1.]],

         [[-1.]]]]), x_U=tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]))>
Traceback (most recent call last):
  File "/auto_LiRPA/script.py", line 35, in <module>
    bounds = net.compute_bounds(x=(tensor,), method="crown")  # fails
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 1206, in compute_bounds
    return self._compute_bounds_main(C=C,
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 1303, in _compute_bounds_main
    self.check_prior_bounds(final)
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  [Previous line repeated 2 more times]
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 804, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 910, in compute_intermediate_bounds
    node.lower, node.upper = self.backward_general(
  File "/auto_LiRPA/auto_LiRPA/backward_bound.py", line 324, in backward_general
    lb, ub = concretize(self, batch_size, output_dim, lb, ub,
  File "/auto_LiRPA/auto_LiRPA/backward_bound.py", line 684, in concretize
    lb = lb + roots[i].perturbation.concretize(
RuntimeError: The size of tensor a (4) must match the size of tensor b (784) at non-singleton dimension 3
ERROR conda.cli.main_run:execute(49): `conda run python script.py` failed. (See above for error)

I know this behaviour is extremely strange, but the only change from the working call is subtracting 1.0 from the lower bound, and the printed shapes are identical in both runs, so I don't think it's a shape issue this time.
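One possible reason the two runs take different code paths despite identical shapes (a hypothesis, not something established in this thread) is that widening the input box changes which neurons are unstable, i.e. whose pre-activation bounds straddle 0; the first traceback in this thread shows auto_LiRPA indexing intermediate bounds with `unstable_idx`. The sketch below uses interval bound propagation through a single hypothetical affine layer (made-up weights, plain Python) to illustrate how moving the lower input bound from 0 to -1 can increase the number of unstable (Leaky)ReLU inputs.

```python
def interval_affine(weights, bias, lower, upper):
    """Propagate the box [lower, upper] through y = W x + b with interval arithmetic."""
    center = [(l + u) / 2 for l, u in zip(lower, upper)]
    radius = [(u - l) / 2 for l, u in zip(lower, upper)]
    out_lower, out_upper = [], []
    for row, b in zip(weights, bias):
        mid = b + sum(w * c for w, c in zip(row, center))
        rad = sum(abs(w) * r for w, r in zip(row, radius))
        out_lower.append(mid - rad)
        out_upper.append(mid + rad)
    return out_lower, out_upper


def count_unstable(lower, upper):
    """Count neurons whose (Leaky)ReLU input can be both negative and positive."""
    return sum(1 for l, u in zip(lower, upper) if l < 0 < u)


# Hypothetical 3-neuron layer on a 2-dimensional input; weights are made up.
W = [[1.0, 1.0], [1.0, -0.5], [0.5, 0.5]]
b = [-0.5, 0.2, 0.1]

results = {}
for x_lower in (0.0, -1.0):
    lo, hi = interval_affine(W, b, [x_lower, x_lower], [1.0, 1.0])
    results[x_lower] = count_unstable(lo, hi)
print(results)  # {0.0: 2, -1.0: 3}
```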

cherrywoods commented 9 months ago

I also confirmed that the error persists when I use a batch size of 10 for lb and ub.