VaxNeRF
Paper | Google Colab
This is the official implementation of VaxNeRF (Voxel-Accelerated NeRF).
This codebase is implemented in JAX and builds on JaxNeRF.
VaxNeRF trains much faster than the original (Jax)NeRF and achieves slightly higher scores!
Installation
Please see the README of JaxNeRF.
Quick start
Training
# build a bounding-volume voxel grid with Visual Hull
python visualhull.py \
--config configs/demo \
--data_dir data/nerf_synthetic/lego \
--voxel_dir data/voxel_dil7/lego \
--dilation 7 \
--thresh 1. \
--alpha_bkgd True
# train VaxNeRF
python train.py \
--config configs/demo \
--data_dir data/nerf_synthetic/lego \
--voxel_dir data/voxel_dil7/lego \
--train_dir logs/lego_vax_c800 \
--num_coarse_samples 800 \
--render_every 2500
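The visualhull.py step builds a bounding voxel grid from the training views before VaxNeRF is trained. The following is a minimal NumPy sketch of classic visual-hull carving, not the code in visualhull.py. It assumes a grid over the cube [-bound, bound]^3, cameras given as 3x4 projection matrices, and boolean foreground masks (for example alpha > 0). It also reads --thresh as the fraction of views in which a voxel must project onto foreground, which is our interpretation rather than documented behavior. The function name and its parameters are hypothetical.

```python
import numpy as np

def carve_visual_hull(masks, projections, resolution=64, bound=1.5, thresh=1.0):
    """Hypothetical visual-hull carving (illustrative sketch only).

    masks: (N, H, W) boolean foreground masks, one per view.
    projections: (N, 3, 4) matrices mapping homogeneous world points to pixels.
    thresh: assumed meaning of --thresh, i.e. the fraction of views that must
        see a voxel as foreground for it to be kept.
    Returns a (resolution, resolution, resolution) boolean occupancy grid.
    """
    n_views, height, width = masks.shape
    # Voxel centers on a regular grid covering [-bound, bound]^3.
    step = 2 * bound / resolution
    coords = -bound + step * (np.arange(resolution) + 0.5)
    xs, ys, zs = np.meshgrid(coords, coords, coords, indexing="ij")
    points = np.stack([xs, ys, zs, np.ones_like(xs)], axis=-1).reshape(-1, 4)

    votes = np.zeros(points.shape[0], dtype=np.int64)
    for mask, proj in zip(masks, projections):
        uvw = points @ proj.T                        # (M, 3) homogeneous pixels
        in_front = uvw[:, 2] > 0
        depth = np.where(in_front, uvw[:, 2], 1.0)   # avoid division by zero
        u = np.floor(uvw[:, 0] / depth).astype(np.int64)
        v = np.floor(uvw[:, 1] / depth).astype(np.int64)
        inside = in_front & (u >= 0) & (u < width) & (v >= 0) & (v < height)
        # A voxel gets a vote from this view if it projects onto foreground.
        hit = np.zeros(points.shape[0], dtype=bool)
        hit[inside] = mask[v[inside], u[inside]]
        votes += hit

    keep = votes >= thresh * n_views
    return keep.reshape(resolution, resolution, resolution)
```

With thresh=1.0 a voxel must be foreground in every view, which is the strict classic hull; a value below 1.0 tolerates some views whose masks miss part of the object.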
Evaluation
python eval.py \
--config configs/demo \
--data_dir data/nerf_synthetic/lego \
--voxel_dir data/voxel_dil7/lego \
--train_dir logs/lego_vax_c800 \
--num_coarse_samples 800
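Samples that fall outside the voxel hull do not need an MLP evaluation. The sketch below illustrates this with NumPy for a single ray: it places uniformly spaced samples between near and far and keeps only those that land in occupied voxels. It is not the repository's implementation; the helper name, the [-bound, bound]^3 grid layout, and the near/far defaults (the values commonly used for the Blender synthetic scenes) are assumptions.

```python
import numpy as np

def occupied_samples(origin, direction, grid, near=2.0, far=6.0,
                     num_samples=800, bound=1.5):
    """Hypothetical helper: return the sample depths t along one ray whose
    points lie in occupied voxels.

    grid: (R, R, R) boolean occupancy covering [-bound, bound]^3 (assumed layout).
    """
    res = grid.shape[0]
    t = np.linspace(near, far, num_samples)            # uniform depths
    points = origin[None, :] + t[:, None] * direction[None, :]
    # Map world coordinates to integer voxel indices.
    idx = np.floor((points + bound) / (2 * bound) * res).astype(np.int64)
    in_box = np.all((idx >= 0) & (idx < res), axis=-1)
    # Samples outside the grid are treated as empty.
    occupied = np.zeros(num_samples, dtype=bool)
    ii = idx[in_box]
    occupied[in_box] = grid[ii[:, 0], ii[:, 1], ii[:, 2]]
    return t[occupied]
```

Only the returned depths would be passed to the MLP, so the cost of a ray scales with how much of it crosses the hull rather than with num_samples.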
Try other NeRFs
Original NeRF
python train.py \
--config configs/demo \
--data_dir data/nerf_synthetic/lego \
--train_dir logs/lego_c64f128 \
--num_coarse_samples 64 \
--num_fine_samples 128 \
--render_every 2500
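In the original NeRF run above, --num_fine_samples 128 adds a second, fine pass. Hierarchical sampling treats the coarse pass's compositing weights as a piecewise-constant PDF along the ray and draws fine depths by inverting its CDF. The sketch below shows that inverse-transform step in NumPy with deterministic, evenly spaced u values; it is a simplified illustration rather than JaxNeRF's code, and sample_fine is a hypothetical name.

```python
import numpy as np

def sample_fine(bins, weights, num_fine):
    """Draw depths by inverting the piecewise-constant CDF given by weights.

    bins: (K + 1,) bin edges along the ray (e.g. midpoints of coarse samples).
    weights: (K,) non-negative coarse compositing weights, one per bin.
    Uses evenly spaced u in (0, 1) so the sketch is deterministic.
    """
    weights = weights + 1e-5                         # avoid an all-zero PDF
    pdf = weights / weights.sum()
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])    # (K + 1,)
    u = (np.arange(num_fine) + 0.5) / num_fine
    # Index of the bin whose CDF interval contains each u.
    idx = np.clip(np.searchsorted(cdf, u, side="right") - 1, 0, len(weights) - 1)
    # Linear interpolation inside the selected bin.
    frac = (u - cdf[idx]) / (cdf[idx + 1] - cdf[idx])
    return bins[idx] + frac * (bins[idx + 1] - bins[idx])
```

Bins with large coarse weights cover a large share of the CDF, so most fine samples end up near the surfaces the coarse pass found.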
VaxNeRF with hierarchical sampling
# hierarchical sampling needs a more heavily dilated voxel grid
python visualhull.py \
--config configs/demo \
--data_dir data/nerf_synthetic/lego \
--voxel_dir data/voxel_dil47/lego \
--dilation 47 \
--thresh 1. \
--alpha_bkgd True
# train VaxNeRF
python train.py \
--config configs/demo \
--data_dir data/nerf_synthetic/lego \
--voxel_dir data/voxel_dil47/lego \
--train_dir logs/lego_vax_c64f128 \
--num_coarse_samples 64 \
--num_fine_samples 128 \
--render_every 2500
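The comment above notes that hierarchical sampling needs a more strongly dilated hull. The sketch below shows binary dilation of a voxel grid in NumPy. It assumes --dilation is the edge length, in voxels, of a cubic structuring element (our reading of the odd values 7, 11, 47, and 51, not something this README states). A cube dilation is separable, so it is applied as three 1-D passes using prefix sums. Both function names are hypothetical.

```python
import numpy as np

def dilate_axis(grid, radius, axis):
    """1-D binary dilation: a cell becomes occupied if any cell within
    `radius` along `axis` is occupied (window sums via prefix sums)."""
    g = np.moveaxis(grid.astype(np.int64), axis, -1)
    n = g.shape[-1]
    # Prefix sums with a leading zero so a window sum is csum[hi] - csum[lo].
    csum = np.concatenate([np.zeros(g.shape[:-1] + (1,), np.int64),
                           np.cumsum(g, axis=-1)], axis=-1)
    lo = np.clip(np.arange(n) - radius, 0, n)
    hi = np.clip(np.arange(n) + radius + 1, 0, n)
    window = csum[..., hi] - csum[..., lo]
    return np.moveaxis(window > 0, -1, axis)

def dilate_voxels(grid, dilation):
    """Dilate a boolean voxel grid with a cube of odd edge length `dilation`."""
    radius = dilation // 2
    for axis in range(grid.ndim):
        grid = dilate_axis(grid, radius, axis)
    return grid
```

Under that assumption, --dilation 47 grows every occupied voxel by 23 voxels in each direction, compared with 3 for --dilation 7.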
Option details
Visual Hull
- Use --dilation 11 / --dilation 51 for the NSVF-Synthetic dataset when training VaxNeRF without / with hierarchical sampling.
- The following options were used for the Lifestyle, Spaceship, and Steamtrain scenes (included in the NSVF dataset) because these scenes do not have an alpha channel:
  - Lifestyle: --thresh 0.95
  - Spaceship: --thresh 0.9
  - Steamtrain: --thresh 0.95
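This README does not say how silhouettes are obtained for scenes without an alpha channel. One simple option, shown here as a hypothetical sketch, is to mark pixels that differ noticeably from a known background color as foreground; the white background and the tol value are assumptions. Masks built this way are noisy, so a carving rule that requires every view to agree can wrongly remove real geometry, which is one plausible reason for the relaxed --thresh values.

```python
import numpy as np

def mask_from_rgb(image, background=(1.0, 1.0, 1.0), tol=0.05):
    """Hypothetical foreground mask for an RGB image without alpha.

    image: (H, W, 3) float array in [0, 1].
    background: assumed background color (white here).
    tol: largest per-channel deviation still counted as background.
    """
    diff = np.abs(image - np.asarray(background)[None, None, :])
    # Foreground wherever any channel deviates from the background by > tol.
    return np.any(diff > tol, axis=-1)
```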
NeRFs
- We used the --small_lr_at_first option for original NeRF training on the Robot and Spaceship scenes to avoid local minima.
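This README does not document what --small_lr_at_first does. A common technique is to start from a reduced learning rate and ramp up to the full rate over the first steps; JaxNeRF's learning-rate schedule supports this through its lr_delay_steps and lr_delay_mult options. The NumPy sketch below follows that style of schedule (log-linear decay with a sine-shaped warm-up), assuming the flag behaves similarly; the function name and the default values are illustrative, not the settings used for Robot or Spaceship.

```python
import numpy as np

def learning_rate(step, lr_init=5e-4, lr_final=5e-6, max_steps=1_000_000,
                  lr_delay_steps=2500, lr_delay_mult=0.01):
    """Log-linear decay from lr_init to lr_final, with an optional warm-up
    that starts at lr_delay_mult times the rate and eases to full strength
    over lr_delay_steps steps along a sine ramp (illustrative defaults)."""
    if lr_delay_steps > 0:
        ramp = np.sin(0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * ramp
    else:
        delay_rate = 1.0
    t = np.clip(step / max_steps, 0, 1)
    # Interpolate linearly in log space between the initial and final rates.
    log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
    return delay_rate * log_lerp
```

With these values the first step runs at 1% of the base rate, which keeps early updates small while the network output is still far from the scene.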
Code modification from JaxNeRF
- You can see the main differences between (Jax)NeRF (jaxnerf branch) and VaxNeRF (vaxnerf branch) here.
- The main branch (derived from the vaxnerf branch) contains the following features:
  - Support for the original NeRF
  - Support for VaxNeRF with hierarchical sampling
  - Support for the NSVF-Synthetic dataset
  - Visualization of the number of sampling points evaluated by the MLP (VaxNeRF)
  - Automatic choice of the number of sampling points to be evaluated (VaxNeRF)
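This README does not explain how the number of evaluated sampling points is chosen automatically. One plausible design, shown as a hypothetical sketch: under jax.jit array shapes must be static, so occupied samples can be gathered into a buffer whose length is the occupied count rounded up to a fixed bucket size, which limits how often the function is recompiled. The function name, the bucket size, and the padding scheme are assumptions.

```python
import numpy as np

def pack_occupied(points, occupied, bucket=1024):
    """Hypothetical: gather occupied sample points into a padded buffer.

    points: (M, 3) sample positions; occupied: (M,) boolean mask.
    Returns (buffer, valid, index): buffer length is a multiple of `bucket`,
    `valid` marks real entries, and `index` maps them back to rows of
    `points` (padding rows point at row 0 and are zeroed out).
    """
    count = int(occupied.sum())
    size = max(bucket, -(-count // bucket) * bucket)  # ceil to a bucket multiple
    index = np.zeros(size, dtype=np.int64)
    index[:count] = np.nonzero(occupied)[0]
    valid = np.arange(size) < count
    buffer = points[index] * valid[:, None]           # zero out padding rows
    return buffer, valid, index
```

MLP outputs for rows where valid is False would be discarded, and the rest scattered back to their original positions with index.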
Citation
Please use the following bibtex for citations:
@misc{kondo2021vaxnerf,
title={VaxNeRF: Revisiting the Classic for Voxel-Accelerated Neural Radiance Field},
author={Naruya Kondo and Yuya Ikeda and Andrea Tagliasacchi and Yutaka Matsuo and Yoichi Ochiai and Shixiang Shane Gu},
year={2021},
eprint={2111.13112},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Please also cite the original NeRF paper and the JaxNeRF implementation:
@inproceedings{mildenhall2020nerf,
title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
year={2020},
booktitle={ECCV},
}
@software{jaxnerf2020github,
author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan},
title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}},
url = {https://github.com/google-research/google-research/tree/master/jaxnerf},
version = {0.0},
year = {2020},
}
Acknowledgement
We would like to express our deep thanks to the creators of NeRF and JaxNeRF.
Have a good VaxNeRF'ed life!