Open cherrywoods opened 12 months ago
Using torch.multiprocessing instead of multiprocessing does not resolve the issue.
Hi @cherrywoods, I debugged a little and found that it crashes when loss.backward() is called in alpha-CROWN. I guess it's probably an issue with the multiprocessing library itself when loss.backward() is called (IBP and CROWN don't call loss.backward()).
Hi, thanks for looking into this! I also investigated whether the issue is with .backward(), but training in a separate process works fine. Also, the following simple example does not crash the subprocess for me:
    import multiprocessing as mp

    import torch
    from torch import nn


    def worker(network, x):
        x.requires_grad = True
        print("Start")
        output = network(x)
        output.backward()
        print("Finished")
        print(x.grad)


    if __name__ == "__main__":
        network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
        x = torch.zeros(1, 10)
        # Named "process" so that the worker function above is not shadowed.
        process = mp.Process(
            target=worker,
            kwargs={"network": network, "x": x},
        )
        process.start()
        process.join()
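Since the subprocess in the original report dies without printing anything, one way to confirm a crash from the parent is to look at the exit code of the process after joining it. The following is a minimal sketch using only the standard library; `describe_exitcode` and `crashing_worker` are hypothetical names, and `os.abort()` merely stands in for whatever fatal error occurs during the bound computation.

```python
import multiprocessing as mp
import os
import signal


def describe_exitcode(exitcode):
    """Turn a Process.exitcode value into a readable message."""
    if exitcode is None:
        return "still running"
    if exitcode == 0:
        return "exited normally"
    if exitcode < 0:
        # A negative exit code -N means the child was terminated by signal N.
        try:
            name = signal.Signals(-exitcode).name
        except ValueError:
            name = f"signal {-exitcode}"
        return f"killed by {name}"
    return f"exited with error code {exitcode}"


def crashing_worker():
    # Hypothetical stand-in for the bound computation:
    # abort the process without producing a Python traceback.
    os.abort()


if __name__ == "__main__":
    process = mp.Process(target=crashing_worker)
    process.start()
    process.join()
    print(describe_exitcode(process.exitcode))
```

A nonzero exit code only shows that the subprocess failed, not why, so this is a first check rather than a diagnosis.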
The crash stack trace in the issue description, which I haven't managed to reproduce yet, confirms that the crash happens during loss.backward, but it also (more concretely) references line 107 in auto_LiRPA/operators/clampmult.py, which contains an assertion. Unfortunately, I don't really know how to debug further than the backward call, because it calls into a C++ backend which then (apparently) calls into auto_LiRPA/operators/clampmult.py ...
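One standard-library tool that might help to see past the backward call is `faulthandler`, which dumps the Python tracebacks of all threads when the process receives a fatal signal such as SIGSEGV or SIGABRT. The sketch below is an assumption about how it could be wired into a subprocess, not something tested against auto_LiRPA; `enable_crash_traceback`, `traced_worker`, and the log file name are hypothetical.

```python
import faulthandler
import os


def enable_crash_traceback(log_path):
    """Enable faulthandler so fatal signals dump Python tracebacks to log_path.

    Returns the open log file, which must stay open while faulthandler is enabled.
    """
    log_file = open(log_path, "w")
    faulthandler.enable(file=log_file, all_threads=True)
    return log_file


def traced_worker(target, *args, **kwargs):
    # Hypothetical wrapper to pass as the Process target instead of `target`.
    log_file = enable_crash_traceback(f"crash_{os.getpid()}.log")
    try:
        return target(*args, **kwargs)
    finally:
        # Only reached if the process survives; on a fatal signal the
        # traceback has already been written to the log file.
        faulthandler.disable()
        log_file.close()
```

It would be used as `mp.Process(target=traced_worker, args=(worker, network, x))`. If the subprocess dies from an ordinary Python exception rather than a fatal signal, the log stays empty and the regular traceback on stderr is the place to look.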
I am trying to compute bounds on multiple models in parallel using the multiprocessing library. This works fine when using IBP or CROWN, but when using alpha-CROWN, I get very nondescript (fatal) errors.

Reproduce
The following Python script runs through, but does not print the bounds, indicating that the subprocess computing the bounds crashed silently. When I replace "alpha-CROWN" with "IBP" or "CROWN" in line 17, the code runs fine, printing the bounds on the console.

Output:
Output with IBP:
System configuration:
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
Additional Context
In my actual project, I get an error message on the console (included below). I could not reproduce this exact error message, but I suspect the underlying issue is the same. The error might appear only in my actual project because pytest invokes the code there, because the main process and the subprocesses communicate via an mp.SimpleQueue, or because the subprocess obtains the bounds from a generator.
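For illustration, the queue-based communication described above could be made to report failures instead of hiding them. This is a rough sketch, not the code from the actual project: `guarded_worker` and `fake_bounds` are hypothetical, and `fake_bounds` only stands in for the real bound computation.

```python
import multiprocessing as mp
import time
import traceback


def guarded_worker(queue, compute, *args):
    """Run compute(*args) and report either the result or the traceback."""
    try:
        queue.put(("ok", compute(*args)))
    except Exception:
        # Send the formatted traceback so the parent can show what failed.
        queue.put(("error", traceback.format_exc()))


def fake_bounds(scale):
    # Hypothetical stand-in for the real bound computation.
    if scale == 0:
        raise ValueError("scale must be nonzero")
    return (-1.0 * scale, 1.0 * scale)


if __name__ == "__main__":
    queue = mp.SimpleQueue()
    process = mp.Process(target=guarded_worker, args=(queue, fake_bounds, 0))
    process.start()
    # Poll instead of blocking on queue.get(), so a subprocess that dies
    # without sending anything does not hang the parent.
    while process.is_alive() and queue.empty():
        time.sleep(0.05)
    if queue.empty():
        print("subprocess sent nothing, exit code:", process.exitcode)
    else:
        status, payload = queue.get()
        print(status)
        print(payload)
    process.join()
```

This only catches Python exceptions. A fatal error in native code would bypass the `except` block, in which case the parent sees an empty queue and a nonzero exit code.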