Official paper repository
This repository contains examples of training on the following datasets and distortions:
The `nppc` folder contains the main code, while the `run_*.py` files contain scripts for running the examples. Since NPPC requires a trained restoration model, each example has both a `run_*_restoration.py` script and a `run_*_nppc.py` script.
The `minimal_example.ipynb` file is a self-contained notebook and can be viewed in Colab:
A restoration model can be trained with the following code (see the `run_*_restoration.py` scripts for complete examples):
```python
import nppc

model = nppc.RestorationModel(
    dataset=...,          # A string selecting the dataset. Options: "mnist" or "celeba_hq_256".
    data_folder=...,      # The path to the folder containing the dataset (the folder above the dataset's folder).
    distortion_type=...,  # A string selecting the distortion type. Options: "inpainting_1", "inpainting_2",
                          # "denoising_1", "colorization_1" or "super_resolution_1".
    net_type=...,         # A string selecting the type of network to use. Options: "unet", "res_unet" or "res_cnn".
    lr=...,               # The learning rate.
    device=...,           # A string selecting the device to use (e.g., "cpu", "cuda:0").
)

trainer = nppc.RestorationTrainer(
    model=model,          # The restoration model defined above.
    batch_size=...,       # The batch size.
    output_folder=...,    # The folder in which the results will be stored.
)
trainer.train(
    n_steps=...,          # The number of update steps to perform during training.
)
```
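Because the dataset, distortion, and network are all selected by plain strings, it can help to catch a typo before launching a long training run. Below is a minimal sketch of such a pre-flight check. `check_options` is a hypothetical helper, not part of `nppc`; the allowed values are copied from the code comments above. It checks each option independently and does not verify whether a given distortion makes sense for a given dataset.

```python
# Hypothetical pre-flight check; the option strings mirror the code comments above.
DATASETS = ("mnist", "celeba_hq_256")
DISTORTIONS = ("inpainting_1", "inpainting_2", "denoising_1", "colorization_1", "super_resolution_1")
NET_TYPES = ("unet", "res_unet", "res_cnn")


def check_options(dataset: str, distortion_type: str, net_type: str) -> None:
    """Raise ValueError if any option string is not one of the documented values."""
    for name, value, allowed in (
        ("dataset", dataset, DATASETS),
        ("distortion_type", distortion_type, DISTORTIONS),
        ("net_type", net_type, NET_TYPES),
    ):
        if value not in allowed:
            raise ValueError(f"{name}={value!r} is not one of {allowed}")


check_options("mnist", "inpainting_1", "unet")  # valid options: returns without error
```

Calling it just before `nppc.RestorationModel(...)` with the same strings fails fast with a readable message if, for example, `"res-unet"` is typed instead of `"res_unet"`.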
The available distortions are:
- `inpainting_1`
- `inpainting_2`
- `denoising_1`
- `colorization_1`
- `super_resolution_1`
The available networks are:
- `unet`
- `res_unet`
- `res_cnn`
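NPPC is built on top of the trained restoration model: its principal components are taken around the restoration output, which ideally approximates the posterior mean E[x | y] (the MMSE estimate). Assuming the restoration model is trained with a squared-error objective (check the `nppc` source for the loss actually used), the posterior mean is exactly what it should learn, because the mean is the constant that minimizes the expected squared error. The toy sketch below checks this numerically in one dimension using hypothetical posterior samples; it does not call `nppc`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical posterior samples x ~ p(x | y) for one fixed measurement y (1-D for clarity).
samples = rng.normal(loc=2.0, scale=0.5, size=10_000)


def mean_squared_error(estimate: float) -> float:
    """Average squared error of a constant estimate against the posterior samples."""
    return float(np.mean((samples - estimate) ** 2))


# Search a grid of candidate estimates with spacing 0.01.
candidates = np.linspace(0.0, 4.0, 401)
errors = [mean_squared_error(c) for c in candidates]
best = candidates[int(np.argmin(errors))]
print(best, samples.mean())  # agree to within half a grid step
```

Since the mean squared error equals the sample variance plus (mean - estimate)^2, the grid point closest to the sample mean always wins.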
An NPPC model can be trained with the following code (see the `run_*_nppc.py` scripts for complete examples):
```python
import nppc

model = nppc.NPPCModel(
    restoration_model_folder=...,  # The path to the folder containing the trained restoration model.
    net_type=...,                  # A string selecting the type of network to use. Options: "unet", "res_unet" or "res_cnn".
    n_dirs=...,                    # The number of principal components (PCs) to predict.
    lr=...,                        # The learning rate.
    device=...,                    # A string selecting the device to use (e.g., "cpu", "cuda:0").
)

## Train
## -----
trainer = nppc.NPPCTrainer(
    model=model,                   # The NPPC model defined above.
    batch_size=...,                # The batch size.
    output_folder=...,             # The folder in which the results will be stored.
)
trainer.train(
    n_steps=...,                   # The number of update steps to perform during training.
)
```
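The `n_dirs` argument sets how many principal components NPPC predicts. In the NPPC method these are the leading principal components of the posterior p(x | y), taken around the restoration model's estimate, and the network learns to output them directly from the measurement without drawing posterior samples. For intuition, or for sanity-checking on toy data, the sketch below computes the same quantity explicitly by running PCA (via an SVD in NumPy) on hypothetical posterior samples. `posterior_pcs` and the toy data are illustrative assumptions, not part of the `nppc` package.

```python
import numpy as np


def posterior_pcs(samples: np.ndarray, x_hat: np.ndarray, n_dirs: int):
    """Leading principal directions of the posterior error around x_hat.

    samples: (n_samples, dim) array of hypothetical posterior samples x ~ p(x | y).
    x_hat:   (dim,) restoration estimate, ideally the posterior mean E[x | y].
    Returns (directions, variances): directions is (n_dirs, dim) with orthonormal rows,
    variances is (n_dirs,) with the error variance along each direction.
    """
    errors = samples - x_hat  # (n_samples, dim)
    # The right singular vectors of the error matrix are the eigenvectors of E[e e^T].
    _, s, vt = np.linalg.svd(errors, full_matrices=False)
    variances = s[:n_dirs] ** 2 / errors.shape[0]
    return vt[:n_dirs], variances


rng = np.random.default_rng(0)
# Toy "posterior" in 5 dimensions with known per-axis standard deviations.
std = np.array([2.0, 1.0, 0.5, 0.1, 0.1])
samples = rng.normal(size=(20_000, 5)) * std
x_hat = samples.mean(axis=0)  # stand-in for the restoration model's output
directions, variances = posterior_pcs(samples, x_hat, n_dirs=2)
print(np.round(variances, 2))  # should be close to [4, 1]
```

Because the toy posterior varies most along the first two coordinate axes, the recovered directions are close to those axes with variances near 4 and 1. PCs are only defined up to sign, so compare directions by absolute value.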