Verified-Intelligence / auto_LiRPA

auto_LiRPA: An Automatic Linear Relaxation based Perturbation Analysis Library for Neural Networks and General Computational Graphs
https://arxiv.org/pdf/2002.12920
Other

alpha-CROWN in multiprocessing #58

Open cherrywoods opened 12 months ago

cherrywoods commented 12 months ago

I am trying to compute bounds on multiple models in parallel using the multiprocessing library. This works fine when using IBP or CROWN, but when using alpha-CROWN, I get uninformative fatal errors.

Reproduce

The following Python script runs to completion but does not print the bounds, which indicates that the subprocess computing the bounds crashed silently. When I replace "alpha-CROWN" with "IBP" or "CROWN" in the compute_bounds call, the script runs fine and prints the bounds on the console.

import multiprocessing as mp

import torch
from torch import nn
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

def compute_bounds_worker(network, lb, ub):
    # The second argument is an example input that BoundedModule uses to trace the network.
    bounded_network = BoundedModule(
        network,
        lb,
    )
    perturbation = PerturbationLpNorm(x_L=lb, x_U=ub)
    midpoint = (ub + lb) / 2
    input_bounded = BoundedTensor(midpoint, ptb=perturbation)
    print("Compute Bounds")
    # Crashes with "alpha-CROWN"; works with method="IBP" or method="CROWN".
    lb, ub = bounded_network.compute_bounds(x=(input_bounded,), method="alpha-CROWN")
    print("Computation Finished")
    print(lb)
    print(ub)

if __name__ == "__main__":
    network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    lb = torch.zeros(1, 10)
    ub = torch.ones(1, 10)

    worker = mp.Process(
        target=compute_bounds_worker,
        kwargs={"network": network, "lb": lb, "ub": ub},
    )
    worker.start()
    worker.join()

Output with alpha-CROWN:

/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
Compute Bounds

Output with IBP:

/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
Compute Bounds
Computation Finished
tensor([[-0.1279]], grad_fn=<AddBackward0>)
tensor([[0.7965]], grad_fn=<AddBackward0>)
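A possible way to make the silent crash in the repro visible (a diagnostic sketch, not something proposed in this thread): Python's standard faulthandler module dumps the Python traceback when the process receives a fatal signal such as SIGABRT, and the parent can read the child's exitcode, which is -N when the child was killed by signal N. The trace under Additional Context below has faulthandler's format, and pytest enables faulthandler by default, which may be why a message appears there but not with the plain script. The sketch assumes Linux/POSIX and the "fork" start method (the Linux default in Python 3.10); crashing_worker and run_and_report are hypothetical names, and os.abort() only stands in for the real crash inside compute_bounds.

```python
import faulthandler
import multiprocessing as mp
import os
import signal


def crashing_worker():
    # Enable faulthandler inside the child: on a fatal signal such as SIGABRT
    # it writes the Python traceback of every thread to stderr before dying.
    faulthandler.enable(all_threads=True)
    # Hypothetical stand-in for a native crash inside compute_bounds().
    os.abort()


def run_and_report(target, start_method="fork"):
    """Run target in a child process and report how the child ended."""
    ctx = mp.get_context(start_method)  # "fork" is POSIX-only
    proc = ctx.Process(target=target)
    proc.start()
    proc.join()
    code = proc.exitcode
    if code is not None and code < 0:
        # A negative exit code -N means the child was killed by signal N.
        print(f"worker killed by {signal.Signals(-code).name}")
    else:
        print(f"worker exited with code {code}")
    return code


if __name__ == "__main__":
    run_and_report(crashing_worker)
```

In the repro itself, calling faulthandler.enable() at the top of compute_bounds_worker and printing worker.exitcode after worker.join() should be enough to see where the child dies.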

System configuration:

Additional Context

In my actual project, I get an error message on the console (included below). I could not reproduce this exact error message, but I suspect the underlying issue is the same. The message may appear only in my actual project because pytest invokes the code there, because the main process and the subprocesses communicate via an mp.SimpleQueue, or because the subprocess obtains the bounds from a generator.

Fatal Python error: Aborted

Current thread 0x00007f019ebdd640 (most recent call first):
  <no Python frame>

Thread 0x00007f02385eb740 (most recent call first):
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/operators/clampmult.py", line 107 in backward
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/autograd/function.py", line 253 in apply
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173 in backward
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/_tensor.py", line 396 in backward
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/optimized_bounds.py", line 843 in get_optimized_bounds
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 1188 in compute_bounds
  ...
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/_pytest/config/__init__.py", line 166 in main

Extension modules: mkl._mklinit, mkl._py_mkl_service, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, scipy._lib._ccallback_c, numpy.linalg.lapack_lite, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.linalg._flinalg, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, torch._C, torch._C._fft, torch._C._linalg, torch._C._nn, torch._C._sparse, torch._C._special (total: 123)
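Related to the mp.SimpleQueue setup mentioned above: if a child crashes before sending its result, a parent blocked in SimpleQueue.get() can wait forever, because the parent itself still holds the queue's write end, so the pipe never reports end-of-file. A minimal sketch (not from the thread), assuming Linux/POSIX and the "fork" start method, uses a one-way mp.Pipe instead and closes the parent's copy of the write end, so that recv() raises EOFError once the child dies. good_worker, crashing_worker, and run_with_result are hypothetical names, and the string "bounds" stands in for real results.

```python
import multiprocessing as mp
import os


def good_worker(conn):
    # Stand-in for a worker that computes bounds and sends them back.
    conn.send("bounds")
    conn.close()


def crashing_worker(conn):
    # Stand-in for a worker that dies from a native crash before sending.
    os.abort()


def run_with_result(target, start_method="fork"):
    """Run target(conn) in a child process; return (result, exitcode).

    result is None if the child died before sending anything.
    """
    ctx = mp.get_context(start_method)
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=target, args=(writer,))
    proc.start()
    # Close the parent's copy of the write end: the child now holds the only
    # one, so its death makes reader.recv() raise EOFError instead of blocking.
    writer.close()
    try:
        result = reader.recv()
    except EOFError:
        result = None
    finally:
        reader.close()
    proc.join()
    return result, proc.exitcode


if __name__ == "__main__":
    print(run_with_result(good_worker))      # ('bounds', 0)
    print(run_with_result(crashing_worker))  # (None, -6) on Linux
```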
cherrywoods commented 12 months ago

Using torch.multiprocessing instead of multiprocessing does not resolve the issue.

shizhouxing commented 11 months ago

Hi @cherrywoods, I debugged this a little and found that it crashes when alpha-CROWN calls loss.backward(). My guess is that it's an issue with the multiprocessing library itself when loss.backward() is called (IBP and CROWN don't call loss.backward()).

cherrywoods commented 11 months ago

Hi, thanks for looking into this! I also investigated whether the issue is with .backward(), but training in a separate process works fine. Also, the following simple example does not crash the subprocess for me:

import multiprocessing as mp

import torch
from torch import nn

def worker(network, x):
    x.requires_grad = True
    print("Start")
    output = network(x)
    output.backward()
    print("Finished")
    print(x.grad)

if __name__ == "__main__":
    network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    x = torch.zeros(1, 10)

    proc = mp.Process(
        target=worker,
        kwargs={"network": network, "x": x},
    )
    proc.start()
    proc.join()
cherrywoods commented 11 months ago

The stack trace in the issue description (the one I haven't managed to reproduce yet) confirms that the crash happens during loss.backward(), and more concretely it points to line 107 in auto_LiRPA/operators/clampmult.py, which contains an assertion. Unfortunately, I don't really know how to debug beyond the backward call, because it calls into a C++ backend, which then (apparently) calls back into auto_LiRPA/operators/clampmult.py...
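On debugging beyond the backward call: when a custom torch.autograd.Function defines backward in Python, the C++ autograd engine calls back into that Python method during loss.backward() (this matches the torch/autograd/function.py apply frame in the trace above), so ordinary Python tools such as print(), breakpoint(), or traceback.print_stack() work inside it. The sketch below uses a hypothetical toy Function (ScaleByTwo, computing y = 2x); it is not auto_LiRPA's clampmult implementation.

```python
import traceback

import torch


class ScaleByTwo(torch.autograd.Function):
    """Hypothetical toy Function (y = 2x) with a Python-defined backward."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # Called by the C++ autograd engine during loss.backward(); a
        # breakpoint() or print() placed here is hit like in any Python code.
        traceback.print_stack(limit=4)  # show the innermost calling frames
        return grad_output * 2


x = torch.ones(3, requires_grad=True)
loss = ScaleByTwo.apply(x).sum()
loss.backward()
print(x.grad)  # tensor([2., 2., 2.])
```

Placing a breakpoint() just before the assertion in clampmult.py's backward would, in the same way, let one inspect the tensors involved when it fails.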