
Conformal prediction for controlling monotonic risk functions. Simple accompanying PyTorch code for conformal risk control in computer vision and natural language processing.
MIT License

Conformal Risk Control

This is the official repository of Conformal Risk Control by Anastasios N. Angelopoulos, Stephen Bates, Adam Fisch, Lihua Lei, and Tal Schuster.

Technical background

In the risk control problem, we are given some loss function $L_i(\lambda) = \ell(X_i, Y_i, \lambda)$. For example, in multi-label classification, you can think of the loss function as the false negative proportion $L_i(\lambda) = 1 - \frac{|Y_i \cap C_{\lambda}(X_i)|}{|Y_i|}$, where $C_{\lambda}(X_i)$ is the set-valued output of a machine learning model. As $\lambda$ grows, so does the set $C_{\lambda}(X_i)$, which shrinks the false negative proportion. We seek to choose $\hat{\lambda}$ based on the first $n$ data points to control the expected value of its loss on a new test point at some user-specified risk level $\alpha$, $$\mathbb{E}\big[L_{n+1}(\hat{\lambda})\big] \leq \alpha.$$
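
To make the false negative proportion concrete, here is a minimal plain-Python sketch for a single example. It assumes, purely for illustration, that the model outputs one score in [0, 1] per label and that $C_{\lambda}(x)$ contains every label whose score is at least $1 - \lambda$; the helper names `prediction_set` and `false_negative_proportion` are hypothetical and do not come from the repository.

```python
from typing import List, Sequence, Set


def prediction_set(scores: Sequence[float], lam: float) -> Set[int]:
    """Hypothetical C_lambda(x): every label whose score is at least 1 - lam.

    Raising lam lowers the threshold, so the set can only grow.
    """
    return {k for k, s in enumerate(scores) if s >= 1.0 - lam}


def false_negative_proportion(y: Set[int], scores: Sequence[float], lam: float) -> float:
    """L_i(lam) = 1 - |Y_i ∩ C_lam(X_i)| / |Y_i|; assumes Y_i is non-empty."""
    covered = y & prediction_set(scores, lam)
    return 1.0 - len(covered) / len(y)


# One example with three labels; labels 0 and 2 are the true labels.
scores = [0.9, 0.6, 0.3]
y = {0, 2}
losses: List[float] = [false_negative_proportion(y, scores, lam) for lam in (0.0, 0.2, 0.5, 0.8)]
print(losses)  # [1.0, 0.5, 0.5, 0.0]
```

The loss never increases as $\lambda$ grows, which is the monotonicity that the risk control method relies on.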

The conformal risk control algorithm is in core/get_lhat.py. It is 5 lines long, including the function header.
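
In the paper, the selection rule takes the smallest $\lambda$ whose inflated empirical risk stays at or below $\alpha$: $\hat{\lambda} = \inf\{\lambda : \frac{n}{n+1}\hat{R}_n(\lambda) + \frac{B}{n+1} \leq \alpha\}$, where $\hat{R}_n(\lambda)$ is the average loss over the $n$ calibration points and $B$ is an upper bound on the loss. The following is a minimal sketch of that rule over a finite grid of $\lambda$ values, written for illustration rather than copied from core/get_lhat.py. The function name `choose_lambda_hat`, the list-of-lists loss table layout, and the fall-back to the largest grid value when no value qualifies are assumptions.

```python
from typing import Sequence


def choose_lambda_hat(
    loss_table: Sequence[Sequence[float]],
    lambdas: Sequence[float],
    alpha: float,
    B: float = 1.0,
) -> float:
    """Sketch of the conformal risk control rule on a finite grid.

    loss_table[i][j] holds L_i(lambdas[j]) for calibration point i. Assumes the
    grid is sorted in increasing order, each loss is non-increasing in lambda,
    and every loss lies in [0, B].
    """
    n = len(loss_table)
    for j, lam in enumerate(lambdas):
        # Empirical risk R_hat_n(lambda): mean loss over the n calibration points.
        r_hat = sum(row[j] for row in loss_table) / n
        # Inflate by n/(n+1) and B/(n+1) to account for the unseen test point.
        if n / (n + 1) * r_hat + B / (n + 1) <= alpha:
            return lam  # smallest grid value meeting the bound
    # Assumption for this sketch: if nothing qualifies, return the most conservative value.
    return lambdas[-1]


# Toy calibration set: n = 9 points with 0/1 losses that switch off as lambda grows.
lambdas = [0.0, 0.25, 0.5, 0.75, 1.0]
leading_ones = [4, 3, 3, 2, 2, 2, 1, 1, 1]
table = [[1.0] * k + [0.0] * (5 - k) for k in leading_ones]
print(choose_lambda_hat(table, lambdas, alpha=0.3))  # 0.75
```

Scanning the grid in increasing order and stopping at the first success is justified by monotonicity: because each $L_i$ is non-increasing, $\hat{R}_n$ is too, so once the inflated risk falls below $\alpha$ it stays below for every larger $\lambda$.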

Examples

Each of the {polyps, coco, hierarchical-imagenet, qa} folders contains a worked example of conformal risk control with a different risk function. polyps does gut polyp segmentation with false negative rate control. coco does multi-label classification with false negative rate control. hierarchical-imagenet does hierarchical classification and chooses the resolution of its prediction by bounding the graph distance to an ancestor of the true label. Finally, qa controls the F1-score in open-world question answering.

Setup

For the computer vision experiments, run

  conda env create -f environment.yml
  conda activate conformal-risk

This will install all dependencies for the vision experiments.

For the question-answering task, follow the instructions in qa/README.md.

Reproducing the experiments

After setting up the environment, enter an example folder and run its risk_histogram.py file. To produce the grids of images in the paper, run the Python file in each folder whose name contains the word "grid".
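
The guarantee $\mathbb{E}\big[L_{n+1}(\hat{\lambda})\big] \leq \alpha$ is an average over random calibration and test data, so one sanity check is to repeat the calibrate-then-test procedure over many random draws and look at the distribution of the realized test risk. The sketch below does this on synthetic data. It is not the repository's risk_histogram.py; the toy loss $L_i(\lambda) = \mathbb{1}\{U_i > \lambda\}$ with $U_i \sim \mathrm{Uniform}(0, 1)$, the grid, and the function names are all assumptions made for illustration.

```python
import random
from typing import List, Sequence


def synthetic_loss(u: float, lam: float) -> float:
    """Toy loss in [0, 1] that is non-increasing in lam: 1 while lam < u, else 0."""
    return 1.0 if u > lam else 0.0


def lambda_hat(cal_u: Sequence[float], lambdas: Sequence[float], alpha: float, B: float = 1.0) -> float:
    """Smallest grid lambda with n/(n+1) * R_hat_n(lambda) + B/(n+1) <= alpha."""
    n = len(cal_u)
    for lam in lambdas:
        r_hat = sum(synthetic_loss(u, lam) for u in cal_u) / n
        if n / (n + 1) * r_hat + B / (n + 1) <= alpha:
            return lam
    return lambdas[-1]  # assumption: fall back to the most conservative value


def realized_test_risks(alpha: float = 0.2, n_cal: int = 300, n_test: int = 300,
                        trials: int = 100, seed: int = 0) -> List[float]:
    """Average test loss at lambda_hat for each of `trials` independent calibration/test draws."""
    rng = random.Random(seed)
    lambdas = [k / 50 for k in range(51)]  # grid 0.00, 0.02, ..., 1.00
    risks = []
    for _ in range(trials):
        # Fresh i.i.d. draws play the role of one random calibration/test split.
        cal_u = [rng.random() for _ in range(n_cal)]
        test_u = [rng.random() for _ in range(n_test)]
        lam = lambda_hat(cal_u, lambdas, alpha)
        risks.append(sum(synthetic_loss(u, lam) for u in test_u) / n_test)
    return risks


risks = realized_test_risks()
print(f"mean realized test risk: {sum(risks) / len(risks):.3f} (target alpha = 0.2)")
```

Plotting `risks` as a histogram shows how the realized risk spreads around $\alpha$. Individual trials can exceed $\alpha$, since the guarantee bounds the mean over calibration sets, not the risk for any single one.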

Citation

@article{angelopoulos2022conformal,
  title={Conformal Risk Control},
  author={Angelopoulos, Anastasios N and Bates, Stephen and Fisch, Adam and Lei, Lihua and Schuster, Tal},
  journal={arXiv preprint arXiv:2208.02814},
  year={2022}
}