Official implementation of VaxNeRF (Voxel-Accelerated NeRF).

Overview

VaxNeRF

Paper | Google Colab

This is the official implementation of VaxNeRF (Voxel-Accelerated NeRF).
VaxNeRF trains much faster than the original (Jax)NeRF and achieves slightly higher scores.

Updates!

[Figure: results for Visual Hull (1 sec), NeRF (10 min), VaxNeRF (10 min), and Vax-MipNeRF (10 min)]


(The results of Vax-MipNeRF are also included in this figure.)

Installation

Please see the README of JaxNeRF.

The jax and jaxlib version pairs that we have tested are as follows.

jax                     0.2.24
jaxlib                  0.1.69+cuda111

jax                     0.2.17
jaxlib                  0.1.65+cuda110

Quick start

Training

# build a bounding-volume voxel with Visual Hull
python visualhull.py \
    --config configs/demo \
    --data_dir data/nerf_synthetic/lego \
    --voxel_dir data/voxel_dil7/lego \
    --dilation 7 \
    --thresh 1. \
    --alpha_bkgd

# train VaxNeRF
python train.py \
    --config configs/demo \
    --data_dir data/nerf_synthetic/lego \
    --voxel_dir data/voxel_dil7/lego \
    --train_dir logs/lego_vax_c800 \
    --num_coarse_samples 800 \
    --render_every 2500
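
The visual hull step above can be pictured as silhouette carving: a voxel survives only if it projects inside the object silhouette in the input views, and the carved shape is then dilated to leave a safety margin. The toy NumPy sketch below uses three orthographic views along the coordinate axes. The function names, the orthographic cameras, the use of `thresh` as a brightness cut-off on a white background, and `steps` as a dilation radius are all illustrative assumptions, not the behaviour of visualhull.py.

```python
import numpy as np

def silhouette(image, thresh):
    """Foreground mask for a white-background image with values in [0, 1]."""
    return image < thresh

def carve(masks, size):
    """Keep voxels whose orthographic projection lies inside every mask.

    masks[a] is the (size, size) silhouette seen along axis a.
    """
    hull = np.ones((size, size, size), dtype=bool)
    for axis, mask in enumerate(masks):
        # Broadcast the 2D mask along its viewing axis and intersect.
        hull &= np.expand_dims(mask, axis)
    return hull

def dilate(vox, steps):
    """Grow the occupied region by `steps` voxels (6-neighbourhood)."""
    for _ in range(steps):
        padded = np.pad(vox, 1)  # pad with False so np.roll cannot wrap content
        grown = padded.copy()
        for axis in range(3):
            grown |= np.roll(padded, 1, axis) | np.roll(padded, -1, axis)
        vox = grown[1:-1, 1:-1, 1:-1]
    return vox

# Toy scene: a dark 2x2 square on a white background, seen from three axes.
size = 8
image = np.ones((size, size))
image[3:5, 3:5] = 0.2
masks = [silhouette(image, thresh=0.95)] * 3

hull = carve(masks, size)
voxel = dilate(hull, steps=1)
print(hull.sum(), voxel.sum())  # 8 32
```

The dilated grid is a superset of the carved hull, which is the point of the margin: rays only need samples inside it, but the true surface should never fall outside.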

Evaluation

python eval.py \
    --config configs/demo \
    --data_dir data/nerf_synthetic/lego \
    --voxel_dir data/voxel_dil7/lego \
    --train_dir logs/lego_vax_c800 \
    --num_coarse_samples 800

Try other NeRFs

Original NeRF

python train.py \
    --config configs/demo \
    --data_dir data/nerf_synthetic/lego \
    --train_dir logs/lego_c64f128 \
    --num_coarse_samples 64 \
    --num_fine_samples 128 \
    --render_every 2500

VaxNeRF with hierarchical sampling

# a small `num_xx_samples` needs a more dilated voxel (see our paper)
python visualhull.py \
    --config configs/demo \
    --data_dir data/nerf_synthetic/lego \
    --voxel_dir data/voxel_dil47/lego \
    --dilation 47 \
    --thresh 1. \
    --alpha_bkgd

# train VaxNeRF
python train.py \
    --config configs/demo \
    --data_dir data/nerf_synthetic/lego \
    --voxel_dir data/voxel_dil47/lego \
    --train_dir logs/lego_vax_c64f128 \
    --num_coarse_samples 64 \
    --num_fine_samples 128 \
    --render_every 2500

Option details

Visual Hull

  • Use --dilation 11 / --dilation 51 for the NSVF-Synthetic dataset when training VaxNeRF without / with hierarchical sampling.
  • The following options were used
  • Since the Lifestyle, Spaceship, and Steamtrain scenes (included in the NSVF dataset) do not have an alpha channel, please use the following options and remove the --alpha_bkgd option.
    • Lifestyle: --thresh 0.95, Spaceship: --thresh 0.9, Steamtrain: --thresh 0.95
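
The per-scene settings listed above can be kept in one place. The sketch below assembles a `visualhull.py` argument list for an NSVF-Synthetic scene from the options stated in this README. The helper name `visualhull_args` is hypothetical, the paths and config are left as parameters, and reusing `--thresh 1.` with `--alpha_bkgd` for scenes that do have an alpha channel is an assumption carried over from the quick-start example.

```python
# Thresholds from this README for scenes without an alpha channel.
NO_ALPHA_THRESH = {"Lifestyle": 0.95, "Spaceship": 0.9, "Steamtrain": 0.95}

def visualhull_args(scene, data_dir, voxel_dir, config, hierarchical=False):
    """Build the visualhull.py command line for one NSVF-Synthetic scene."""
    # README: --dilation 11 without, --dilation 51 with hierarchical sampling.
    dilation = 51 if hierarchical else 11
    args = [
        "python", "visualhull.py",
        "--config", config,
        "--data_dir", data_dir,
        "--voxel_dir", voxel_dir,
        "--dilation", str(dilation),
    ]
    if scene in NO_ALPHA_THRESH:
        # No alpha channel: use the per-scene threshold and drop --alpha_bkgd.
        args += ["--thresh", str(NO_ALPHA_THRESH[scene])]
    else:
        # Assumption: same flags as the lego quick-start example.
        args += ["--thresh", "1.", "--alpha_bkgd"]
    return args

# Placeholder paths, for illustration only.
print(" ".join(visualhull_args("Spaceship", "DATA_DIR", "VOXEL_DIR", "configs/demo")))
```

The returned list can be passed to `subprocess.run` directly, which avoids shell quoting issues when scene directories contain spaces.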

NeRFs

  • We used the --small_lr_at_first option for original NeRF training on the Robot and Spaceship scenes to avoid a local minimum.

Code modification from JaxNeRF

  • You can see the main difference between (Jax)NeRF (jaxnerf branch) and VaxNeRF (vaxnerf branch) here
  • The main branch (derived from the vaxnerf branch) contains the following features.
    • Support for original NeRF
    • Support for VaxNeRF with hierarchical sampling
    • Support for the NSVF-Synthetic dataset
    • Visualization of the number of sampling points evaluated by the MLP (VaxNeRF)
    • Automatic choice of the number of sampling points to be evaluated (VaxNeRF)
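
The last two features concern how many sample points actually reach the MLP. Skipping samples outside the voxel leaves each ray with a different number of points, while JIT-compiled code prefers fixed shapes. One generic way to reconcile the two, shown as a NumPy sketch and not necessarily what this repository does, is to flatten all samples, gather the indices that fall inside the voxel, pad them to a fixed budget, evaluate the network on that batch only, and scatter the results back. `fake_mlp`, `evaluate_inside`, `budget` and the array shapes are hypothetical.

```python
import numpy as np

def fake_mlp(points):
    """Stand-in for the NeRF MLP: one density value per 3D point."""
    return np.linalg.norm(points, axis=-1)

def evaluate_inside(points, inside, budget):
    """Evaluate `fake_mlp` only on samples flagged by `inside`.

    points: (num_rays, num_samples, 3); inside: (num_rays, num_samples) bool.
    Returns densities of shape (num_rays, num_samples), zero outside the voxel.
    """
    flat_points = points.reshape(-1, 3)
    idx = np.flatnonzero(inside)
    if idx.size > budget:
        raise ValueError("budget too small for the number of inside samples")
    # Pad the index list so the MLP always sees exactly `budget` points.
    padded = np.zeros(budget, dtype=np.int64)
    padded[:idx.size] = idx
    valid = np.arange(budget) < idx.size
    sigma = fake_mlp(flat_points[padded]) * valid  # padding contributes nothing
    out = np.zeros(flat_points.shape[0])
    out[padded[valid]] = sigma[valid]
    return out.reshape(inside.shape)

rng = np.random.default_rng(0)
points = rng.uniform(-1.0, 1.0, size=(4, 16, 3))
inside = np.abs(points).max(axis=-1) < 0.5  # toy "voxel": a centred cube
# Worst-case budget for this toy input; a real one would be much smaller
# than num_rays * num_samples.
sigma = evaluate_inside(points, inside, budget=64)
print(sigma.shape, int(inside.sum()))
```

The trade-off is between a budget that is too small (inside samples get dropped or an error is raised) and one that is too large (wasted MLP evaluations on padding).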

Citation

Please use the following bibtex for citations:

@article{kondo2021vaxnerf,
  title={VaxNeRF: Revisiting the Classic for Voxel-Accelerated Neural Radiance Field},
  author={Kondo, Naruya and Ikeda, Yuya and Tagliasacchi, Andrea and Matsuo, Yutaka and Ochiai, Yoichi and Gu, Shixiang Shane},
  journal={arXiv preprint arXiv:2111.13112},
  year={2021}
}

and also cite the original NeRF paper and JaxNeRF implementation:

@inproceedings{mildenhall2020nerf,
  title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
  author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
  year={2020},
  booktitle={ECCV},
}

@software{jaxnerf2020github,
  author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan},
  title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}},
  url = {https://github.com/google-research/google-research/tree/master/jaxnerf},
  version = {0.0},
  year = {2020},
}

Acknowledgement

We'd like to express deep thanks to the inventors of NeRF and JaxNeRF.

Have a good VaxNeRF'ed life!

Comments
  • Do all sampling positions need to be fed into the NeRF model after using visual hull?

    Thanks for your great work. I would like to know whether you eliminate the positions that are not in the visual hull before feeding them into the NeRF model. If you do, fewer positions need to be computed, but each ray then has a different number of positions. How do you solve this problem?

    opened by DRosemei 5
  • Fails to train on GPU

    Hi,

    Congrats on your work! I'm trying to Vax the lego dataset from NeRF on a GCP VM with an A100. I have followed the instructions from the JaxNeRF repo and installed cudatoolkit 11.0.221 through Conda. The visual hull was created successfully, but when I run training I get the following message:

    E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NOT_INITIALIZED: initialization error
    140511409292096 xla_bridge.py:243] Unable to initialize backend 'gpu': FAILED_PRECONDITION: No visible GPU devices

    Here's the output of nvidia-smi:

    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  NVIDIA A100-SXM...  On   | 00000000:00:04.0 Off |                    0 |
    | N/A   31C    P0    44W / 400W |    111MiB / 40536MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |    0   N/A  N/A      1178      G   /usr/lib/xorg/Xorg                 95MiB |
    |    0   N/A  N/A      1218      G   /usr/bin/gnome-shell               14MiB |
    +-----------------------------------------------------------------------------+
    

    This driver and CUDA version are what I installed manually; I pasted the output here to show that nvidia-smi works. For VaxNeRF I've installed cudatoolkit and cuDNN (8.2) through Conda.

    I want to train on my own data after I've trained on lego. Would it be possible for you to add instructions for that?

    Thanks!

    opened by sixftninja 4
  • Question about the digitize function

    Hello, I don't fully understand this function. I know that digitize converts p's location to a voxel index, but why do you add rsize to p (p+rsize)?

    rsize is half of far - near, so adding it changes p's location. You then clip the voxelized p to [0, vsize-1], but in real-world coordinates (e.g. LLFF), the voxelized p in the range [0, vsize-1] may not contain the object.

    Or is rsize here just an empirical offset?

    from functools import partial
    import jax.numpy as jnp
    from jax import jit

    @partial(jit, static_argnums=(1,2,))
    def digitize(p, rsize, vsize):
        p = jnp.round((p+rsize) * (vsize/(rsize*2)))
        return jnp.clip(p.astype(jnp.uint16), 0, vsize-1)
    
    opened by zhangjian94cn 2
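
Read literally, the function quoted above shifts a coordinate from [-rsize, rsize] to [0, 2*rsize] and rescales it to voxel units, so adding `rsize` moves the origin from the centre of the cube to its corner. The plain-Python sketch below reproduces that arithmetic for single coordinates. Treating the scene as a cube of half-width `rsize` centred at the origin is an interpretation of the code, not a statement from the authors, and `digitize_scalar` and the numbers used are hypothetical.

```python
def digitize_scalar(p, rsize, vsize):
    """Scalar version of the mapping: world coordinate -> voxel index."""
    # Shift [-rsize, rsize] to [0, 2 * rsize], then scale to [0, vsize].
    index = round((p + rsize) * (vsize / (rsize * 2)))
    # Clamp in signed arithmetic, so anything outside the cube lands on
    # the first or last voxel.
    return min(max(index, 0), vsize - 1)

rsize, vsize = 2.0, 400
for p in (-2.0, 0.0, 1.0, 2.0, 5.0):
    print(p, digitize_scalar(p, rsize, vsize))
# -2.0 0
# 0.0 200
# 1.0 300
# 2.0 399
# 5.0 399
```

Under this reading, points outside the cube are not discarded but snapped to the boundary voxels, which matches the concern raised in the question about scenes that are not centred or bounded this way.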
  • Does not work on real scenes

    I tried the visual hull code on a few real scenes (from the LLFF and Tanks and Temples datasets). All scenes return wrong visual hulls because of the ground plane and background noise. Does this approach only work on synthetic data?

    opened by qiaosongwang 2
  • Using the concept of voxel acceleration for Deblur-NeRF

    Hi @naruya,

    This is amazing work. I am working on deblurring NeRF using blurry input scenes. Is it possible to use the concept of VaxNeRF in this NeRF model?

    Thanks

    opened by addy1997 1
  • TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.

    (jaxnerf) ➜  VaxNeRF git:(main) python train.py \
        --config configs/demo \
        --data_dir data/nerf_synthetic/lego \
        --train_dir logs/lego_c64f128 \
        --num_coarse_samples 64 \
        --num_fine_samples 128 \
        --render_every 2500
    I0530 19:00:30.947573 140039295452352 xla_bridge.py:330] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
    I0530 19:00:31.048327 140039295452352 xla_bridge.py:330] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
    /root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:501: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
      warnings.warn(
    /root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:514: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
      warnings.warn(
    Traceback (most recent call last):
      File "train.py", line 316, in <module>
        app.run(main)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
        _run_main(main, args)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
        sys.exit(main(argv))
      File "train.py", line 312, in main
        train(FLAGS.max_steps)
      File "train.py", line 146, in train
        model, variables = models.get_model(key, dataset.peek(), FLAGS)
      File "/root/workspace/VaxNeRF/nerf/models.py", line 33, in get_model
        return model_dict[args.model](key, example_batch, args)
      File "/root/workspace/VaxNeRF/nerf/models.py", line 322, in construct_nerf
        init_variables = model.init(
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/module.py", line 1224, in init
        _, v_out = self.init_with_output(
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/module.py", line 1191, in init_with_output
        return self.apply(
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/module.py", line 1156, in apply
        return apply(
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/core/scope.py", line 831, in wrapper
        y = fn(root, *args, **kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/module.py", line 1440, in scope_fn
        return fn(module.clone(parent=scope), *args, **kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/transforms.py", line 1239, in wrapped_fn
        return prewrapped_fn(self, *args, **kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
        return self._call_wrapped_method(fun, args, kwargs)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/flax/linen/module.py", line 648, in _call_wrapped_method
        y = fun(self, *args, **kwargs)
      File "/root/workspace/VaxNeRF/nerf/models.py", line 153, in __call__
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
      File "/root/workspace/VaxNeRF/nerf/model_utils.py", line 173, in volumetric_rendering
        jnp.broadcast_to([1e10], z_vals[Ellipsis, :1].shape)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 345, in _broadcast_to
        _check_arraylike("broadcast_to", arr)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 298, in _check_arraylike
        raise TypeError(msg.format(fun_name, type(arg), pos))
    jax._src.traceback_util.UnfilteredStackTrace: TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "train.py", line 316, in <module>
        app.run(main)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
        _run_main(main, args)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
        sys.exit(main(argv))
      File "train.py", line 312, in main
        train(FLAGS.max_steps)
      File "train.py", line 146, in train
        model, variables = models.get_model(key, dataset.peek(), FLAGS)
      File "/root/workspace/VaxNeRF/nerf/models.py", line 33, in get_model
        return model_dict[args.model](key, example_batch, args)
      File "/root/workspace/VaxNeRF/nerf/models.py", line 322, in construct_nerf
        init_variables = model.init(
      File "/root/workspace/VaxNeRF/nerf/models.py", line 153, in __call__
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
      File "/root/workspace/VaxNeRF/nerf/model_utils.py", line 173, in volumetric_rendering
        jnp.broadcast_to([1e10], z_vals[Ellipsis, :1].shape)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 345, in _broadcast_to
        _check_arraylike("broadcast_to", arr)
      File "/root/.pyenv/versions/jaxnerf/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 298, in _check_arraylike
        raise TypeError(msg.format(fun_name, type(arg), pos))
    TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.
    
    opened by naruya 1
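
The traceback points at `jnp.broadcast_to([1e10], ...)`: the JAX release in this environment rejects a Python list in that position. A likely workaround, sketched below under the assumption that nothing else in `volumetric_rendering` depends on the list, is to wrap the constant in `jnp.array` before broadcasting. The `z_vals` values and shape are made up for the demonstration, and the `dists` line only mirrors the usual NeRF step of appending a large final interval.

```python
import jax.numpy as jnp

# Hypothetical sample depths with shape (num_rays, num_samples).
z_vals = jnp.linspace(2.0, 6.0, 8).reshape(2, 4)

# Passing an array instead of a list satisfies JAX's array-like check.
far = jnp.broadcast_to(jnp.array([1e10]), z_vals[..., :1].shape)

# Distances between adjacent samples, with a very large last interval.
dists = jnp.concatenate([z_vals[..., 1:] - z_vals[..., :-1], far], axis=-1)
print(dists.shape)  # (2, 4)
```

Pinning jax and jaxlib to one of the version pairs listed under Installation is the other option, since those are the combinations the authors report having tested.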
  • VaxNeRF cannot reach convergence

    This gives a coarse result after 2500 steps:

    python train.py \
        --config configs/demo \
        --data_dir data/nerf_synthetic/lego \
        --train_dir logs/lego_c64f128 \
        --num_coarse_samples 64 \
        --num_fine_samples 128 \
        --render_every 2500

    But training with the voxel does not converge:

    python train.py \
        --config configs/demo \
        --data_dir data/nerf_synthetic/lego \
        --voxel_dir data/voxel_dil7/lego \
        --train_dir logs/lego_vax_c800 \
        --num_coarse_samples 800 \
        --render_every 2500

    The result is an entirely white background.

    opened by YuiNsky 0