jaxngp
This repository contains JAX implementations of:
- a multiresolution hash encoder (JAX)
- an accelerated volume renderer for fast training of NeRFs (CUDA + JAX), with
  - occupancy grid pruning during ray marching
  - early stop during ray color integration
- an inference-time renderer for real-time rendering of NeRFs (CUDA + JAX)
- a GUI for visualizing, interacting with, and exploring NeRFs @seimeicyx
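To illustrate what "early stop during ray color integration" means, here is a minimal pure-Python sketch of front-to-back compositing along a single ray. It is not the repository's CUDA kernel: the function name `integrate_ray` and the threshold value `1e-4` are assumptions made for this illustration only.

```python
import math


def integrate_ray(densities, colors, step_sizes, transmittance_threshold=1e-4):
    """Composite samples front to back, stopping once the ray is nearly opaque.

    `transmittance_threshold` is an assumed value for this sketch.
    """
    transmittance = 1.0
    rgb = [0.0, 0.0, 0.0]
    n_used = 0
    for sigma, color, dt in zip(densities, colors, step_sizes):
        if transmittance < transmittance_threshold:
            break  # early stop: remaining samples cannot change the color noticeably
        alpha = 1.0 - math.exp(-sigma * dt)  # opacity of this sample
        weight = transmittance * alpha       # contribution of this sample
        rgb = [acc + weight * c for acc, c in zip(rgb, color)]
        transmittance *= 1.0 - alpha         # light that passes through the sample
        n_used += 1
    return rgb, transmittance, n_used


# One empty sample followed by a very dense (opaque) green sample.
densities = [0.0, 1000.0, 1000.0, 5.0]
colors = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 1.0, 1.0)]
step_sizes = [0.1, 0.1, 0.1, 0.1]
rgb, transmittance, n_used = integrate_ray(densities, colors, step_sizes)
print(rgb, transmittance, n_used)
```

In this toy ray the loop exits right after the opaque sample, so the two samples behind it are never evaluated.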
Benchmarks
| | mic | ficus | chair | hotdog | materials | drums | ship | lego | average |
|---|---|---|---|---|---|---|---|---|---|
| @33.7k steps (this codebase) | 37.04 | 33.14 | 35.10 | 37.21 | 29.50 | 25.85 | 30.93 | 35.95 | 33.09 |
| @51.2k steps (this codebase) | 37.07 | 33.17 | 35.16 | 37.26 | 29.50 | 25.86 | 30.94 | 36.03 | 33.124 |
| paper (instant-ngp) | 36.22 | 33.51 | 35.00 | 37.40 | 29.78 | 26.02 | 31.10 | 36.39 | 33.176 |
For each scene, the network is trained on 100 training images (800x800 each) for 30k steps with
default parameters. The reported PSNR is averaged across 200 test images.
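As a reference for how a per-scene number like those above can be obtained, the sketch below computes PSNR per image and then averages over the test images. It assumes pixel values in [0, 1] and images given as flat lists of floats; the helper names `psnr` and `mean_psnr` are hypothetical and not taken from this codebase.

```python
import math


def psnr(pred, target, max_value=1.0):
    """Peak signal-to-noise ratio (dB) between two equally sized flat images."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


def mean_psnr(image_pairs):
    """Average the per-image PSNR over (prediction, ground truth) pairs."""
    values = [psnr(pred, target) for pred, target in image_pairs]
    return sum(values) / len(values)


# Two tiny "images" with two pixels each, standing in for rendered test views.
pairs = [
    ([0.5, 0.5], [0.4, 0.6]),
    ([0.2, 0.8], [0.2, 0.7]),
]
print(mean_psnr(pairs))
```

Averaging per-image PSNR values is not the same as computing one PSNR from the pooled MSE, so the two conventions should not be mixed when comparing against other implementations.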
Environment Setup
jaxngp manages environments with Nix, but it's also possible to set up the environment with any other package manager (e.g. Conda).
With Nix (recommended)
- Install Nix with the official installer or the nix-installer.
- With the nix executable available, clone this repository and set up the environment:
$ git clone https://github.com/blurgyy/jaxngp.git
$ cd jaxngp/
$ NIXPKGS_ALLOW_UNFREE=1 nix develop --impure
This will download (or build if necessary) all the dependencies, and open a new shell with all the dependencies configured.
Note: to avoid the built environment being garbage collected when nix gc or nix-collect-garbage is called, append a --profile <PATH> argument:
$ NIXPKGS_ALLOW_UNFREE=1 nix develop --impure --profile .git/devshell.profile
With Conda
TODO
Running
The program's entry point is python3 -m app.nerf. It provides three subcommands: train, test, and gui. Pass -h|--help to any of the subcommands to see its usage, e.g.:
python3 -m app.nerf train --help
```markdown
usage: __main__.py train [-h] --exp-dir PATH [--raymarch.diagonal-n-steps INT]
[--raymarch.perturb | --raymarch.no-perturb]
[--raymarch.density-grid-res INT] [--render.bg FLOAT FLOAT FLOAT]
[--render.random-bg | --render.no-random-bg]
[--scene.sharpness-threshold FLOAT] [--scene.world-scale FLOAT]
[--scene.resolution-scale FLOAT] [--scene.camera-near FLOAT]
[--logging {DEBUG,INFO,WARN,WARNING,ERROR,CRITICAL}] [--seed INT]
[--summary | --no-summary] [--frames-val PATH [PATH ...]]
[--ckpt {None}|PATH] [--lr FLOAT] [--tv-scale FLOAT] [--bs INT]
[--n-epochs INT] [--n-batches INT] [--data-loop INT] [--validate-every INT]
[--keep INT] [--keep-every {None}|INT]
[--raymarch-eval.diagonal-n-steps INT]
[--raymarch-eval.perturb | --raymarch-eval.no-perturb]
[--raymarch-eval.density-grid-res INT] [--render-eval.bg FLOAT FLOAT FLOAT]
[--render-eval.random-bg | --render-eval.no-random-bg]
PATH [PATH ...]
╭─ positional arguments ───────────────────────────────────────────────────────────────────────────╮
│ PATH [PATH ...] directories or transform.json files containing data for training │
│ (required) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ arguments ──────────────────────────────────────────────────────────────────────────────────────╮
│ -h, --help show this help message and exit │
│ --exp-dir PATH experiment artifacts are saved under this directory (required) │
│ --frames-val PATH [PATH ...] │
│ directories or transform.json files containing data for validation │
│ (default: ) │
│ --ckpt {None}|PATH if specified, continue training from this checkpoint (default: None) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ raymarch arguments ─────────────────────────────────────────────────────────────────────────────╮
│ raymarching/rendering options during training │
│ ──────────────────────────────────────────────────────────────────────────────────────────────── │
│ --raymarch.diagonal-n-steps INT │
│ for calculating the length of a minimal ray marching step, the NGP paper │
│ uses 1024 (appendix E.1) (default: 1024) │
│ --raymarch.perturb, --raymarch.no-perturb │
│ whether to fluctuate the first sample along the ray with a tiny │
│ perturbation (default: True) │
│ --raymarch.density-grid-res INT │
│ resolution for the auxiliary density/occupancy grid, the NGP paper uses │
│ 128 (appendix E.2) (default: 128) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ render arguments ───────────────────────────────────────────────────────────────────────────────╮
│ raymarching/rendering options during training │
│ ──────────────────────────────────────────────────────────────────────────────────────────────── │
│ --render.bg FLOAT FLOAT FLOAT │
│ background color for transparent parts of the image, has no effect if │
│ `random_bg` is True (default: 1.0 1.0 1.0) │
│ --render.random-bg, --render.no-random-bg │
│ ignore `bg` specification and use random color for transparent parts of │
│ the image (default: True) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ scene arguments ────────────────────────────────────────────────────────────────────────────────╮
│ raymarching/rendering options during training │
│ ──────────────────────────────────────────────────────────────────────────────────────────────── │
│ --scene.sharpness-threshold FLOAT │
│ images with sharpness lower than this value will be discarded (default: │
│ -1.0) │
│ --scene.world-scale FLOAT │
│ scale both the scene's camera positions and bounding box with this │
│ factor (default: 1.0) │
│ --scene.resolution-scale FLOAT │
│ scale input images in case they are too large, camera intrinsics are │
│ also scaled to match the updated image resolution. (default: 1.0) │
│ --scene.camera-near FLOAT │
│ (default: 0.3) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ common arguments ───────────────────────────────────────────────────────────────────────────────╮
│ --logging {DEBUG,INFO,WARN,WARNING,ERROR,CRITICAL} │
│ log level (default: INFO) │
│ --seed INT random seed (default: 1000000007) │
│ --summary, --no-summary │
│ display model information after model init (default: False) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ train arguments ────────────────────────────────────────────────────────────────────────────────╮
│ training hyper parameters │
│ ──────────────────────────────────────────────────────────────────────────────────────────────── │
│ --lr FLOAT learning rate (default: 0.01) │
│ --tv-scale FLOAT scalar multiplied to total variation loss, set this to a positive value │
│ to enable calculation of TV loss (default: 0.0) │
│ --bs INT batch size (default: 1048576) │
│ --n-epochs INT training epochs (default: 50) │
│ --n-batches INT batches per epoch (default: 1024) │
│ --data-loop INT loop within training data for this number of iterations, this helps │
│ reduce the effective dataloader overhead. (default: 1) │
│ --validate-every INT will validate every `validate_every` epochs, set this to a large value │
│ to disable validation (default: 10) │
│ --keep INT number of latest checkpoints to keep (default: 1) │
│ --keep-every {None}|INT │
│ how many epochs should a new checkpoint to be kept (in addition to │
│ keeping the last `keep` checkpoints) (default: 8) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ raymarch-eval arguments ────────────────────────────────────────────────────────────────────────╮
│ raymarching/rendering options for validating during training │
│ ──────────────────────────────────────────────────────────────────────────────────────────────── │
│ --raymarch-eval.diagonal-n-steps INT │
│ for calculating the length of a minimal ray marching step, the NGP paper │
│ uses 1024 (appendix E.1) (default: 1024) │
│ --raymarch-eval.perturb, --raymarch-eval.no-perturb │
│ whether to fluctuate the first sample along the ray with a tiny │
│ perturbation (default: False) │
│ --raymarch-eval.density-grid-res INT │
│ resolution for the auxiliary density/occupancy grid, the NGP paper uses │
│ 128 (appendix E.2) (default: 128) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ render-eval arguments ──────────────────────────────────────────────────────────────────────────╮
│ raymarching/rendering options for validating during training │
│ ──────────────────────────────────────────────────────────────────────────────────────────────── │
│ --render-eval.bg FLOAT FLOAT FLOAT │
│ background color for transparent parts of the image, has no effect if │
│ `random_bg` is True (default: 0.0 0.0 0.0) │
│ --render-eval.random-bg, --render-eval.no-random-bg │
│ ignore `bg` specification and use random color for transparent parts of │
│ the image (default: False) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
```
The output above is just an example and might not reflect the state of the latest codebase.
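A few quantities follow directly from the defaults listed in the help output above. The sketch below works them out under two assumptions that are not confirmed by this README: that each batch corresponds to one training step, and that the minimal ray marching step is the diagonal of a unit cube (length sqrt(3)) divided by diagonal-n-steps, as in the NGP paper's sqrt(3)/1024.

```python
import math

# Defaults copied from the help output above.
n_epochs = 50
n_batches = 1024
diagonal_n_steps = 1024
density_grid_res = 128

# Assumption: one training step per batch.
total_steps = n_epochs * n_batches

# Assumption: scene bounded by a unit cube, whose diagonal has length sqrt(3).
min_step = math.sqrt(3.0) / diagonal_n_steps

# Number of cells in a cubic density/occupancy grid of this resolution.
n_grid_cells = density_grid_res ** 3

print(total_steps, min_step, n_grid_cells)
```

Under the first assumption, the default schedule amounts to 51,200 steps, which agrees with the "@51.2k steps" row in the benchmark table.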
Examples
train
- Just train with the default parameters on the lego scene from the NeRF-synthetic dataset:
$ python3 -m app.nerf train data/nerf_synthetic/lego/transforms_train.json --exp-dir=logs/lego
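- Train for 10 epochs, with a batch size of 262144, on all 400 images (100 train + 100 validation + 200 test) from the lego scene. The --{...} form relies on the shell's brace expansion and is equivalent to passing --n-epochs=10 --bs=262144: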
$ python3 -m app.nerf train data/nerf_synthetic/lego --exp-dir=logs/lego-trainvaltest --{n-epochs=10,bs=262144}
- Train on the training and validation splits of the drums scene, with a weight of 1e-5 on the Total Variation (TV) loss (by default this weight is 0):
$ python3 -m app.nerf train data/nerf_synthetic/drums/transforms_{train,val}.json --exp-dir=logs/drums-trainval --tv-scale=1e-5
- Train on the training split of the mic scene, validate with the validation split, and validate after every epoch:
$ python3 -m app.nerf train data/nerf_synthetic/mic/transforms_train.json --frames-val=data/nerf_synthetic/mic/transforms_val.json --exp-dir=logs/mic --validate-every=1
Note: The validation images are logged to TensorBoard, under the logs/ directory inside --exp-dir. View them in a browser with:
$ tensorboard serve --logdir logs/mic/logs/ --bind_all
TensorBoard 2.10.0 at http://localhost:6006/ (Press CTRL+C to quit)
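The transforms_*.json files passed to the commands above follow the NeRF-synthetic layout: a top-level camera_angle_x (horizontal field of view in radians) and a list of frames, each with a file_path and a 4x4 transform_matrix. The sketch below shows how the pixel focal length can be derived from such a file. The helper names are hypothetical, the example values are made up, and scenes in the Instant-NGP style may store intrinsics under different keys.

```python
import json
import math


def focal_length_px(camera_angle_x: float, image_width: int) -> float:
    """Pinhole focal length in pixels from the horizontal field of view (radians)."""
    return 0.5 * image_width / math.tan(0.5 * camera_angle_x)


def summarize(meta: dict, image_width: int = 800) -> dict:
    """Collect the frame count and focal length from a parsed transforms file."""
    return {
        "n_frames": len(meta["frames"]),
        "focal_px": focal_length_px(meta["camera_angle_x"], image_width),
    }


# A tiny stand-in for json.load() on a real transforms file; values are made up.
example = json.loads("""
{
  "camera_angle_x": 0.6911,
  "frames": [
    {
      "file_path": "./train/r_0",
      "transform_matrix": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 4], [0, 0, 0, 1]]
    }
  ]
}
""")
summary = summarize(example)
print(summary)
```

The default image_width of 800 matches the 800x800 images mentioned in the Benchmarks section; pass the actual width for other data.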
test
- Test using the latest checkpoint under the logs/mic directory, with the camera intrinsics and extrinsics of the mic scene's test split:
$ python3 -m app.nerf test data/nerf_synthetic/mic/transforms_test.json --ckpt=logs/mic/ --exp-dir=output
- Test on the mic scene with given camera extrinsics, but override the camera's resolution to 1920x1080, and use white as background color:
$ python3 -m app.nerf test data/nerf_synthetic/mic/transforms_test.json --ckpt=logs/mic/ --exp-dir=output --camera-override.{width=1920,height=1080} --render.bg 1 1 1
- Test with a generated orbiting trajectory (see Demos for an example) on the mic scene, with resolution 1920x1080:
$ python3 -m app.nerf test data/nerf_synthetic/mic --trajectory=orbit --ckpt=logs/mic/ --exp-dir=output --camera-override.{width=1920,height=1080}
gui
Note: The gui subcommand accepts all the parameters of the train subcommand, plus a --viewport parameter (the default values of --viewport are sensible enough to leave as-is).
Running on Custom Data
A helper CLI (a thin COLMAP wrapper built on pycolmap) is provided for creating an Instant-NGP-compatible scene from a casually captured video or a directory of images:
$ python3 -m utils create-scene --help # see full usage
usage: ...
$
$ # capture or download a video
$ mkdir -p data/_src
$ curl https://github.com/blurgyy/jaxngp/assets/44701880/022a7b3c-344d-418f-aba0-0ccb9bfeb374 -Lo data/_src/gundam.mp4
$
$ # create a scene from the video, set scene bound to 16, with a background color model
$ python3 -m utils create-scene data/_src/gundam.mp4 --root-dir=data/gundam --matcher=Sequential --fps=5 --bound=16 --bg
[...]
After the scene has been created, the remaining steps (training/validating/testing) are the same as before:
$ # train on all the registered images
$ python3 -m app.nerf train data/gundam --exp-dir=logs/gundam
[...]
$
$ # render novel views at a resolution of 1920x1080, save results as images and a video (video shown below in the Demos section)
$ python3 -m app.nerf test data/gundam --{ckpt,exp-dir}=logs/gundam --trajectory=orbit --camera-override.{width=1920,height=1080,no-distortion} --orbit.high=1 --save-as="video and images"
Demos