EliasNehme / NPPC

Official implementation of the NeurIPS 2023 paper: "Uncertainty Quantification via Neural Posterior Principal Components"

Uncertainty Quantification via Neural Posterior Principal Components (NPPC)

Official paper repository

arXiv

This repository contains examples for training on the following datasets and distortions:

The nppc folder contains the main code, while the run_*.py files contain scripts for running the examples. Since NPPC requires a trained restoration model, each example has both a run_*_restoration.py file and a run_*_nppc.py file, and the restoration model must be trained before the NPPC model.
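The principal components that NPPC predicts describe how plausible restorations spread around the restoration model's output, which is why the restoration model has to be trained first. To make that target concrete, here is a minimal NumPy sketch under stated assumptions: it draws many samples from a toy Gaussian posterior p(x|y) for a single observation, centers them on an estimate x_hat, and extracts the leading principal directions and their variances with an SVD. NPPC itself is designed to predict such directions with a network in a single forward pass rather than by sampling; the function `posterior_pcs` and the toy data are hypothetical and not part of the nppc package.

```python
import numpy as np

def posterior_pcs(samples: np.ndarray, x_hat: np.ndarray, n_dirs: int):
    """Top principal components of the errors (samples - x_hat).

    samples: (n_samples, d) draws from a posterior p(x | y) for one fixed y (hypothetical input).
    x_hat:   (d,) the estimate the components are centered on.
    Returns (directions with shape (n_dirs, d), variances with shape (n_dirs,)).
    """
    errors = samples - x_hat                # center on the estimate
    # Rows of vt are orthonormal directions, sorted by decreasing singular value.
    _, s, vt = np.linalg.svd(errors, full_matrices=False)
    variances = s**2 / samples.shape[0]     # mean squared error along each direction
    return vt[:n_dirs], variances[:n_dirs]

# Toy posterior (hypothetical): a 3-D Gaussian whose spread is largest along the last axis.
rng = np.random.default_rng(0)
std = np.array([0.1, 0.5, 2.0])
samples = rng.standard_normal((50_000, 3)) * std
x_hat = samples.mean(axis=0)                # stand-in for a posterior-mean estimate

dirs, variances = posterior_pcs(samples, x_hat, n_dirs=2)
print(np.abs(dirs).round(2))
print(variances.round(2))
```

In this toy case the first returned direction should align with the last coordinate axis (the one with the largest spread), with a variance close to 4, and the second with the middle axis, with a variance close to 0.25.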

The minimal_example.ipynb file is a self-contained notebook that can be opened in Colab.

Training a restoration model

Training a restoration model can be done using the following code (see the script files for examples):

import nppc

model = nppc.RestorationModel(
    dataset=...,          # Options: "mnist" or "celeba_hq_256". A string selecting the dataset.
    data_folder=...,      # The path to the folder containing the dataset (the folder above the dataset's folder).
    distortion_type=...,  # Options: "inpainting_1", "inpainting_2", "denoising_1", "colorization_1" or "super_resolution_1".
                          # A string selecting the distortion type.
    net_type=...,         # Options: "unet", "res_unet" or "res_cnn". A string selecting the type of network to be used.
    lr=...,               # The learning rate.
    device=...,           # A string selecting the device to be used (e.g. "cpu" or "cuda:0").
)
trainer = nppc.RestorationTrainer(
    model=model,          # The restoration model to train.
    batch_size=...,       # The batch size.
    output_folder=...,    # The folder in which the results will be stored.
)
trainer.train(
    n_steps=...,          # The number of update steps to perform during training.
)
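In the paper's formulation, the principal components are centered on the restoration model's output, which is intended to approximate the posterior mean E[x|y]. Assuming the restoration model is trained with a mean-squared-error loss (an assumption here; the script files are the authority on the actual loss), the NumPy sketch below illustrates on a toy jointly Gaussian problem why minimizing MSE recovers the posterior mean. The data, the noise level, and the linear "model" are hypothetical and unrelated to `nppc.RestorationModel`.

```python
import numpy as np

# Toy setup (hypothetical): a 2-D clean signal x ~ N(0, I) and a degraded
# observation y that keeps only a noisy copy of x[0].
rng = np.random.default_rng(0)
n_samples, noise_std = 100_000, 0.5
x = rng.standard_normal((n_samples, 2))
y = x[:, 0] + noise_std * rng.standard_normal(n_samples)

# "Train" a linear restoration model x_hat = a * y by minimizing the mean squared error.
# Least squares returns one coefficient per output coordinate.
coeffs, *_ = np.linalg.lstsq(y[:, None], x, rcond=None)
a = coeffs[0]

# For this jointly Gaussian problem the posterior mean is
# E[x | y] = (y / (1 + noise_std**2), 0), so the MSE fit should be close to these coefficients.
expected = np.array([1.0 / (1.0 + noise_std**2), 0.0])
print("fitted:", a.round(3), "analytic:", expected.round(3))
```

The second coordinate is unobserved, so its best MSE estimate is simply its prior mean (zero); the uncertainty left in such coordinates is what posterior principal components are meant to expose.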

The available distortions are:

The available networks are:

Training an NPPC model

Training an NPPC model can be done using the following code (see the script files for examples):

import nppc

model = nppc.NPPCModel(
    restoration_model_folder=...,  # The path to the folder containing the trained restoration model.
    net_type=...,                  # Options: "unet", "res_unet" or "res_cnn". A string selecting the type of network to be used.
    n_dirs=...,                    # The number of principal components (PCs) to predict.
    lr=...,                        # The learning rate.
    device=...,                    # A string selecting the device to be used (e.g. "cpu" or "cuda:0").
)

## Train
## -----
trainer = nppc.NPPCTrainer(
    model=model,
    batch_size=...,     # The batch size.
    output_folder=...,  # The folder in which the results will be stored.
)
trainer.train(
    n_steps=...,        # The number of update steps to perform during training.
)
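The n_dirs argument sets how many directions the NPPC network outputs. As a simplified sketch of the training idea (our reading of the paper, not the repository's loss function), the raw predicted directions can be orthonormalized with Gram-Schmidt, implemented below with a QR decomposition (equivalent up to the sign of each direction), and trained to capture as much of the restoration error e = x - x_hat as possible. The paper's objective also fits the variance along each direction, which is omitted here; the helper names `orthonormalize` and `captured_error_energy` and the toy error batch are hypothetical.

```python
import numpy as np

def orthonormalize(raw_dirs: np.ndarray) -> np.ndarray:
    """Gram-Schmidt via QR: (n_dirs, d) raw directions -> orthonormal rows spanning the same space."""
    q, _ = np.linalg.qr(raw_dirs.T)   # q: (d, n_dirs) with orthonormal columns
    return q.T

def captured_error_energy(dirs: np.ndarray, errors: np.ndarray) -> float:
    """Batch mean of sum_k (w_k . e)^2, where each row of errors is e = x - x_hat.

    A simplified training loss would be the negative of this value (hypothetical).
    """
    proj = errors @ dirs.T            # (batch, n_dirs) projections onto each direction
    return float(np.mean(np.sum(proj**2, axis=1)))

rng = np.random.default_rng(0)
n_dirs, d = 2, 5
# Hypothetical batch of restoration errors with a different spread along each axis.
errors = rng.standard_normal((10_000, d)) * np.array([3.0, 2.0, 1.0, 0.5, 0.1])

# Stand-in for raw network outputs: random directions, then orthonormalized.
random_dirs = orthonormalize(rng.standard_normal((n_dirs, d)))

# Leading right singular vectors of the error batch (its principal directions).
_, _, vt = np.linalg.svd(errors, full_matrices=False)
pc_dirs = vt[:n_dirs]

print("random directions:", captured_error_energy(random_dirs, errors))
print("principal directions:", captured_error_energy(pc_dirs, errors))
```

No set of n_dirs orthonormal directions captures more of this energy than the leading principal directions, which is the intuition for why maximizing it pushes the predicted directions toward the principal components of the error.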