logix-project / logix

AI Logging for Interpretability and Explainability🔬
Apache License 2.0

Resolve deadlock during synchronization #68

Closed sangkeun00 closed 8 months ago

sangkeun00 commented 8 months ago

Previously, the code would hang at random at torch.cuda.synchronize() in state.py. You can reproduce this intermittent deadlock by running:

python examples/mnist_influence/compute_influences_distributed.py

Debugging showed that the cause is the use of set() to store the various states in state.py. A Python set has no defined element order, and the iteration order of a set of strings can differ from one process to another, so different ranks tried to synchronize different states at the same step (e.g. rank 0: covariance_state, rank 1: covariance_counter). The mismatched collective calls then deadlocked silently. I found this bug by running with TORCH_DISTRIBUTED_DEBUG=DETAIL.
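To illustrate the ordering issue, here is a minimal sketch of one common way to make every rank walk the states in the same order: iterate over a sorted copy of the names rather than over the set itself. The helper ordered_state_names and the name mean_state are hypothetical and only for illustration; this is not necessarily how state.py was changed in this PR.

```python
def ordered_state_names(state_names):
    """Return the state names in an order that is identical on every rank.

    sorted() depends only on the string values, not on per-process hashing,
    so each process gets the same sequence.
    """
    return sorted(state_names)


# Hypothetical state names; each rank may iterate over its set in a
# different order.
rank0_states = {"covariance_state", "covariance_counter", "mean_state"}
rank1_states = {"mean_state", "covariance_counter", "covariance_state"}

order0 = ordered_state_names(rank0_states)
order1 = ordered_state_names(rank1_states)

# In a real distributed run, each rank would loop over this list and issue
# one collective call (for example torch.distributed.all_reduce) per state,
# so the n-th call on every rank refers to the same state.
for name0, name1 in zip(order0, order1):
    assert name0 == name1
```

Storing the states in a list or a dict (which preserves insertion order) would serve the same purpose, provided every rank inserts the states in the same order.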

In addition, I improved __init__ and confirmed that users can now do:

import analog

analog.init(project="test")

analog.watch(model)
analog.setup({"log": "grad", "statistic": "kfac"})