Plenoxels: Radiance Fields without Neural Networks

Overview

Alex Yu*, Sara Fridovich-Keil*, Matthew Tancik, Qinhong Chen, Benjamin Recht, Angjoo Kanazawa

UC Berkeley

Website and video: https://alexyu.net/plenoxels

arXiv: https://arxiv.org/abs/2112.05131

Note: This JAX implementation is intended to be high-level and user-serviceable, but it is much slower (roughly 1 hour per epoch) than the CUDA implementation https://github.com/sxyu/svox2 (roughly 1 minute per epoch), and the two versions are not feature-for-feature identical. This JAX version can likely be sped up significantly, and we may push performance improvements and extra features in the future. Currently, this version only supports bounded scenes and trains using SGD without regularization.

Citation:

@misc{yu2021plenoxels,
      title={Plenoxels: Radiance Fields without Neural Networks}, 
      author={{Alex Yu and Sara Fridovich-Keil} and Matthew Tancik and Qinhong Chen and Benjamin Recht and Angjoo Kanazawa},
      year={2021},
      eprint={2112.05131},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Setup

We recommend setting up a conda environment and installing the packages listed in requirements.txt.

Downloading data

Currently, this implementation only supports NeRF-Blender, which is available at:

https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1

Voxel Optimization (aka Training)

The training file is plenoptimize.py; its flags specify many options to control the optimization (scene, resolution, training duration, when to prune and subdivide voxels, where the training data is, where to save rendered images and model checkpoints, etc.). You can also set the frequency of evaluation, which will compute the validation PSNR and render validation images (comparing the reconstruction to the ground truth).
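The validation PSNR reported during evaluation is conventionally computed from the mean squared error; for images scaled to [0, 1] it reduces to -10 · log10(MSE). The sketch below assumes float images in [0, 1]; the function name `psnr` is illustrative and is not taken from plenoptimize.py, which may compute it differently.

```python
import numpy as np

def psnr(pred: np.ndarray, gt: np.ndarray) -> float:
    """PSNR in dB for images with values in [0, 1] (peak value 1)."""
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2)
    # 10 * log10(peak**2 / MSE) with peak = 1 simplifies to -10 * log10(MSE).
    # Identical images give MSE = 0 and therefore an infinite PSNR.
    return float(-10.0 * np.log10(mse))

# A uniform error of 0.1 per channel gives MSE = 0.01, i.e. 20 dB.
gt = np.zeros((4, 4, 3))
pred = np.full((4, 4, 3), 0.1)
print(round(psnr(pred, gt), 6))  # 20.0
```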

Comments
  • About the negative density in voxel

    Thanks for your astonishing work.

    I notice that some voxels have negative or zero density. I was wondering how the density is defined when it is negative. According to Formulas 1 and 2 in the Plenoxels paper, a negative density may produce an unreasonable RGB result. However, when I experimented with clipping the density in the trained model, the PSNR decreased.

    I didn't find any constraint on the range of density values in the code. Maybe I missed something, so I hope to get your opinion.

    Thanks!

    opened by Fangkang515 3
  • Out of memory while trying to allocate 196608000 b

    I tested the code on a Titan V GPU (12 GB) with 128 GB of system memory, and it ran out of memory. The paper presents per-scene optimization on a single GPU; how much memory does the optimization need?

    Traceback (most recent call last):
      File "plenoptimize.py", line 479, in <module>
        main()
      File "plenoptimize.py", line 434, in main
        mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
      File "plenoptimize.py", line 434, in <lambda>
        mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
      File "plenoptimize.py", line 293, in get_loss_rays
        rgb, disp, acc, weights, voxel_ids = plenoxel.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
    jax._src.traceback_util.FilteredStackTrace: RuntimeError: Resource exhausted: Out of memory while trying to allocate 196608000 bytes.

    The stack trace above excludes JAX-internal frames. The following is the original exception that occurred, unmodified.


    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "plenoptimize.py", line 479, in <module>
        main()
      File "plenoptimize.py", line 434, in main
        mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/api.py", line 809, in value_and_grad_f
        ans, vjp_py = _vjp(f_partial, *dyn_args)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1878, in _vjp
        out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
        out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 498, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "plenoptimize.py", line 434, in <lambda>
        mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
      File "plenoptimize.py", line 293, in get_loss_rays
        rgb, disp, acc, weights, voxel_ids = plenoxel.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss
        out_flat = xla.xla_call(
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1405, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 308, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1405, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 191, in process_call
        jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 299, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1405, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 579, in _xla_call_impl
        return compiled_fun(*args)
      File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 830, in _execute_compiled
        out_bufs = compiled.execute(input_bufs)
    RuntimeError: Resource exhausted: Out of memory while trying to allocate 196608000 bytes.

    opened by c03n424 3
  • Is using sigmoid(rgb) or clamp(rgb) correct?

    Hi,

    Thanks for your awesome work! In the PlenOctree paper and in this implementation, you apply the sigmoid function to the SH sum: https://github.com/sarafridov/plenoxels/blob/975d2619a75b4e6fadd8b72e9edbf30fdf7e559c/plenoxel.py#L38

    In the CUDA implementation, however, it seems to me that you apply clamp(x + 0.5, 0, inf): https://github.com/sxyu/svox2/blob/59984d6c4fd3d713353bafdcb011646e64647cc7/svox2/csrc/render_lerp_kernel_cuvol.cu#L102

        outv += weight * fmaxf(lane_color_total + 0.5f, 0.f); // Clamp to [+0, infty)

    What's the reason for the different implementations, and which is more accurate? If it is the second one, why do you add the constant 0.5f? What's the intuition for this offset?

    Thanks a lot! :)

    opened by lukasHoel 2
  • interpolating small negative sigmas

    Hi, I'm wondering about the trilinear interpolation of sigmas. From my understanding, going from sigma to alpha looks like this: α = 1 - exp(-max(sum(σᵢ * wᵢ), 0)), where σᵢ and wᵢ are the neighbor sigmas and weights for trilinear interpolation.

    I think it's possible for the sigmas to take on small negative values, which could cancel out other small positive sigmas. For example, consider 8 neighboring sigmas with values [-0.001, -0.001, -0.001, -0.001, -0.001, -0.001, -0.001, 0.007] and weights of 1/8: the interpolated sigma is zero even though there is a positive sigma in the neighborhood.

    Applying the ReLU function to the sigmas before interpolating would fix this, if it is actually a problem. Or maybe the effect is negligible?

    opened by jenkspt 2
  • mul got incompatible shapes

    Exception has occurred: TypeError
    mul got incompatible shapes for broadcasting: (800, 800, 1, 1536), (800, 1, 800, 3).
      File "/plenoxels/plenoptimize.py", line 479, in <module>
        main()
      File "/plenoxels/plenoptimize.py", line 419, in main
        mse, data_grad = jax.value_and_grad(lambda grid: get_loss((indices, grid), c2w, gt, H, W, focal, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
      File "/plenoxels/plenoptimize.py", line 419, in <lambda>
        mse, data_grad = jax.value_and_grad(lambda grid: get_loss((indices, grid), c2w, gt, H, W, focal, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
      File "/plenoxels/plenoptimize.py", line 285, in get_loss
        rgb, disp, acc, weights, voxel_ids = plenoxel.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
      File "/plenoxels/plenoxel.py", line 490, in render_rays
        pts = rays_o[:, jnp.newaxis, :] + intersections[:, :, jnp.newaxis] * rays_d[:, jnp.newaxis, :]  # [n_rays, n_intersections, 3]

    opened by c03n424 2
  • "self.use_sphere_bound" appears twice in definition.

    Hi, thanks for your work. There is a line I cannot understand in "opt/util/dataset_base.py":

        def __init__(self):
            self.ndc_coeffs = (-1, -1)
            self.use_sphere_bound = False
            self.should_use_background = True  # a hint
            self.use_sphere_bound = True

    As shown above, self.use_sphere_bound is assigned twice. Is this a mistake?

    opened by PSGLGDnoChampion 1
  • About the speed for jax plenoxels

    Thanks for your astonishing work.

    “but is much slower (roughly 1 hour per epoch) than the CUDA implementation https://github.com/sxyu/svox2 (roughly 1 minute per epoch)”

    As mentioned above, the JAX version is much slower than the CUDA implementation. As far as I know, JAX can also run on the GPU (via CUDA). Why is the speed gap so large? And is the fast training of Plenoxels due to the custom CUDA implementation rather than to not using a neural network?

    opened by Fangkang515 1
  • abnormal termination at line 363 in plenoptimize.py

    rays = np.stack([get_rays_np(H, W, focal, p) for p in train_c2w[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]

    If I change it to the following code:

    # Precompute all the training rays and shuffle them
    tp = train_c2w[:,:3,:4]
    for p in tp:
        a = get_rays_np(H, W, focal, p)
        rays = np.stack([a], 0)
        print(p)
    

    The shapes are:

    rays is (1, 2, 800, 800, 3)
    multi_lowpass(train_gt[:,None], FLAGS.resolution).astype(np.float32) is (100, 1, 800, 800, 3)
    

    When I np.concatenate the two arrays, it raises: ValueError: all the input array dimensions for the concatenation axis must match exactly,

    opened by c03n424 1
  • question about code

    Many thanks for your great work! I have a question about the code; could you give a more detailed explanation of the following?

        # Compute when the rays enter and leave the grid
        offsets_pos = jax.lax.stop_gradient((radius - rays_o) / rays_d)
        offsets_neg = jax.lax.stop_gradient((-radius - rays_o) / rays_d)
        offsets_in = jax.lax.stop_gradient(jnp.minimum(offsets_pos, offsets_neg))
        offsets_out = jax.lax.stop_gradient(jnp.maximum(offsets_pos, offsets_neg))
        start = jax.lax.stop_gradient(jnp.max(offsets_in, axis=-1, keepdims=True))
        stop = jax.lax.stop_gradient(jnp.min(offsets_out, axis=-1, keepdims=True))
        first_intersection = jax.lax.stop_gradient(rays_o + start * rays_d)

    opened by UestcJay 6
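On the question above about negative density: the discrete volume-rendering rule used by NeRF-style methods computes per-sample opacity α_i = 1 - exp(-σ_i δ_i) and transmittance T_i = exp(-Σ_{j<i} σ_j δ_j). A negative σ_i therefore makes α_i negative and lets T grow past that sample. This NumPy sketch shows the effect with and without a ReLU on σ; it assumes a single ray and a uniform step size δ, and it illustrates the equations rather than reproducing render_rays in plenoxel.py.

```python
import numpy as np

def render_weights(sigma, delta, clamp=False):
    """Discrete volume-rendering weights w_i = T_i * alpha_i along one ray.

    alpha_i = 1 - exp(-sigma_i * delta), T_i = exp(-sum_{j<i} sigma_j * delta).
    delta is a single uniform step size here, for simplicity.
    If clamp is True, negative densities are set to zero first (a ReLU).
    """
    sigma = np.asarray(sigma, dtype=np.float64)
    if clamp:
        sigma = np.maximum(sigma, 0.0)
    tau = sigma * delta                      # optical depth of each segment
    alpha = 1.0 - np.exp(-tau)               # per-sample opacity
    # Exclusive cumulative sum, so the first sample sees T_0 = 1.
    trans = np.exp(-np.concatenate([[0.0], np.cumsum(tau)[:-1]]))
    return trans * alpha, alpha, trans

sigma = [2.0, -1.0, 3.0]
w_raw, alpha_raw, trans_raw = render_weights(sigma, delta=0.5)
w_relu, _, _ = render_weights(sigma, delta=0.5, clamp=True)

print(alpha_raw[1] < 0)              # True: negative density gives a negative opacity
print(trans_raw[2] > trans_raw[1])   # True: transmittance grows after the negative sample
print(w_raw[1] < 0, w_relu[1] == 0)  # True True: clamping removes the negative weight
```

Clamping changes what the renderer computes, so clipping the density of a model that was trained without the clamp may lower PSNR; applying a clamp during training is a different choice from applying it afterwards.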
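The trilinear-interpolation question above can be checked numerically with the issue's own numbers. This sketch compares clamping after interpolation (the formula quoted in the issue) with clamping each neighbour before interpolating; the eight weights are the equal 1/8 weights from the example, not real trilinear weights for a particular sample point.

```python
import numpy as np

# Eight neighbouring densities and equal trilinear weights, as in the issue's example.
sigmas = np.array([-0.001] * 7 + [0.007])
weights = np.full(8, 1.0 / 8.0)

# Interpolate the raw densities, then clamp (the order in the quoted formula).
interp_then_relu = max(float(np.dot(weights, sigmas)), 0.0)
# Clamp each neighbour first, then interpolate (the alternative the issue proposes).
relu_then_interp = float(np.dot(weights, np.maximum(sigmas, 0.0)))

print(interp_then_relu)  # approximately 0: the negatives cancel the positive neighbour
print(relu_then_interp)  # 0.000875
```

Clamping before interpolation avoids the cancellation, but a ReLU has zero gradient for negative inputs, so neighbours with negative density would stop receiving gradient from that sample. Whether the cancellation matters in practice is the open question in the issue.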
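For the sigmoid versus clamp question above, the sketch below puts the two color activations side by side as functions of the SH sum x. One plausible reading of the +0.5 offset (an interpretation, not something confirmed in this thread) is that it makes the clamped version agree with the sigmoid at x = 0, so a grid with all-zero SH coefficients renders mid-grey under either activation.

```python
import numpy as np

def sigmoid_color(sh_sum):
    """Sigmoid on the SH sum (as the issue describes for plenoxel.py): output in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-sh_sum))

def relu_offset_color(sh_sum):
    """ReLU with a +0.5 offset, mirroring fmaxf(x + 0.5f, 0.f): output in [0, inf)."""
    return np.maximum(sh_sum + 0.5, 0.0)

x = np.array([-1.0, 0.0, 0.25, 2.0])
print(sigmoid_color(x))      # bounded above by 1
print(relu_offset_color(x))  # unbounded above; exactly 0 for x <= -0.5

# Both activations send an SH sum of 0 to 0.5 (mid-grey).
print(sigmoid_color(0.0), relu_offset_color(0.0))  # 0.5 0.5
```

Beyond the shared midpoint, the two differ in range and gradients: the sigmoid is bounded in (0, 1) and saturates for large |x|, while fmaxf(x + 0.5, 0) has slope 1 above x = -0.5, zero gradient below it, and no upper bound.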
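The shapes in the "mul got incompatible shapes" report above, (800, 800, 1, 1536) and (800, 1, 800, 3), are what the quoted render_rays line produces if rays_o and rays_d arrive image-shaped as [H, W, 3] with H = W = 800 instead of flattened to [n_rays, 3]. That is an inference from the error message, not a confirmed diagnosis. The NumPy sketch below reproduces the pattern with small sizes (NumPy raises ValueError where JAX raises TypeError; both follow the same broadcasting rules); `sample_points` and the evenly spaced sample distances are hypothetical.

```python
import numpy as np

H, W, n_samples = 4, 5, 6  # tiny stand-ins for 800, 800, 1536

rays_o = np.zeros((H, W, 3))              # per-pixel origins, image-shaped
rays_d = np.ones((H, W, 3))               # per-pixel directions, image-shaped
t = np.linspace(0.0, 1.0, n_samples)      # hypothetical sample distances along each ray
intersections = np.broadcast_to(t, (H, W, n_samples))

def sample_points(rays_o, rays_d, intersections):
    # Same broadcasting pattern as the quoted line in plenoxel.py, written for
    # rays_o, rays_d of shape [n_rays, 3] and intersections of shape [n_rays, n_samples].
    return rays_o[:, np.newaxis, :] + intersections[:, :, np.newaxis] * rays_d[:, np.newaxis, :]

try:
    sample_points(rays_o, rays_d, intersections)  # (4, 5, 1, 6) * (4, 1, 5, 3) cannot broadcast
except ValueError as err:
    print("broadcast error:", err)

# Flattening the pixel dimensions into a single ray dimension makes the shapes line up.
pts = sample_points(rays_o.reshape(-1, 3), rays_d.reshape(-1, 3),
                    intersections.reshape(-1, n_samples))
print(pts.shape)  # (20, 6, 3)
```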
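In the "abnormal termination" report above, the modified loop reassigns `rays` on every iteration, so only the last image survives and the result has shape (1, 2, 800, 800, 3) instead of (100, 2, 800, 800, 3); np.concatenate with the (100, 1, 800, 800, 3) ground-truth array then fails because the first dimensions (1 and 100) differ. The sketch below shows both versions with small sizes. `get_rays_stub` is a hypothetical stand-in for get_rays_np that only assumes it returns an (origins, directions) pair of [H, W, 3] arrays, as the original line's `[N, ro+rd, H, W, 3]` comment suggests, and the concatenation along axis 1 is likewise an assumption about the surrounding code.

```python
import numpy as np

def get_rays_stub(H, W, c2w):
    """Hypothetical stand-in for get_rays_np: returns (rays_o, rays_d), each [H, W, 3].

    Origins are the camera centre; directions are a placeholder, not a pinhole model.
    """
    rays_o = np.broadcast_to(c2w[:3, -1], (H, W, 3))
    rays_d = np.broadcast_to(c2w[:3, 2], (H, W, 3))
    return rays_o, rays_d

n_images, H, W = 3, 4, 5
poses = np.tile(np.eye(4)[:3, :4], (n_images, 1, 1))  # stand-in for train_c2w[:, :3, :4]

# Reassigning inside the loop keeps only the last image.
for p in poses:
    last_only = np.stack([get_rays_stub(H, W, p)], 0)
print(last_only.shape)  # (1, 2, 4, 5, 3)

# One (rays_o, rays_d) pair per pose, stacked once: [N, ro+rd, H, W, 3].
rays = np.stack([get_rays_stub(H, W, p) for p in poses], 0)
gt = np.zeros((n_images, 1, H, W, 3))     # stand-in for the low-passed training images
rays_rgb = np.concatenate([rays, gt], 1)  # [N, ro+rd+rgb, H, W, 3]
print(rays_rgb.shape)  # (3, 3, 4, 5, 3)
```

For scale, with 100 training images at 800 × 800 the stacked rays alone hold 100 × 2 × 800 × 800 × 3 = 384,000,000 values (about 1.5 GB in float32, 3.1 GB in float64), so host memory is one thing worth checking if the original line terminates abnormally; this is a guess, not a confirmed cause.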
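The snippet in the "question about code" issue above is the slab method for intersecting rays with an axis-aligned box. Reading the ±radius planes in the code, the grid is assumed to occupy the cube [-radius, radius]^3. For each axis, the ray crosses the two planes at t = (±radius - o) / d; the smaller value is where it enters that slab and the larger is where it leaves. The ray is inside the box only while it is inside all three slabs, so it enters at the largest per-axis entry (start) and leaves at the smallest per-axis exit (stop); if start > stop, it misses the box. The NumPy version below mirrors the snippet without jax.lax.stop_gradient, which leaves forward values unchanged and only blocks gradients through these geometric quantities.

```python
import numpy as np

def ray_box_intersection(rays_o, rays_d, radius):
    """Slab-method entry/exit distances for rays against the cube [-radius, radius]^3.

    rays_o, rays_d: [n_rays, 3]. Returns start, stop: [n_rays, 1]; a ray hits
    the box only where start <= stop.
    """
    with np.errstate(divide="ignore"):  # d == 0 on an axis gives +/-inf, which min/max handle
        offsets_pos = (radius - rays_o) / rays_d    # distance to the +radius plane on each axis
        offsets_neg = (-radius - rays_o) / rays_d   # distance to the -radius plane on each axis
    offsets_in = np.minimum(offsets_pos, offsets_neg)   # per-axis entry distance
    offsets_out = np.maximum(offsets_pos, offsets_neg)  # per-axis exit distance
    start = np.max(offsets_in, axis=-1, keepdims=True)  # inside only once inside every slab
    stop = np.min(offsets_out, axis=-1, keepdims=True)  # outside as soon as any slab is left
    return start, stop

# Two rays along +x: one through the cube, one passing above it (y = 2).
rays_o = np.array([[-3.0, 0.0, 0.0], [-3.0, 2.0, 0.0]])
rays_d = np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
start, stop = ray_box_intersection(rays_o, rays_d, radius=1.0)
first_intersection = rays_o + start * rays_d
print(start.ravel(), stop.ravel())  # first ray enters at t=2 and leaves at t=4; second has start > stop
print(first_intersection[0])        # entry point of the first ray on the x = -1 face
```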
Owner: Sara Fridovich-Keil
Unofficial & improved implementation of NeRF--: Neural Radiance Fields Without Known Camera Parameters

[Unofficial code-base] NeRF--: Neural Radiance Fields Without Known Camera Parameters [ Project | Paper | Official code base ] ⬅️ Thanks the original

Jianfei Guo 239 Dec 22, 2022
This repository contains the source code for the paper "DONeRF: Towards Real-Time Rendering of Compact Neural Radiance Fields using Depth Oracle Networks",

DONeRF: Towards Real-Time Rendering of Compact Neural Radiance Fields using Depth Oracle Networks Project Page | Video | Presentation | Paper | Data L

Facebook Research 281 Dec 22, 2022
This is the code for Deformable Neural Radiance Fields, a.k.a. Nerfies.

Deformable Neural Radiance Fields This is the code for Deformable Neural Radiance Fields, a.k.a. Nerfies. Project Page Paper Video This codebase conta

Google 1k Jan 9, 2023
Open source repository for the code accompanying the paper 'Non-Rigid Neural Radiance Fields Reconstruction and Novel View Synthesis of a Deforming Scene from Monocular Video'.

Non-Rigid Neural Radiance Fields This is the official repository for the project "Non-Rigid Neural Radiance Fields: Reconstruction and Novel View Synt

Facebook Research 296 Dec 29, 2022
Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.

This repository contains the code release for Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. This implementation is written in JAX, and is a fork of Google's JaxNeRF implementation. Contact Jon Barron if you encounter any issues.

Google 625 Dec 30, 2022
Code for KiloNeRF: Speeding up Neural Radiance Fields with Thousands of Tiny MLPs

KiloNeRF: Speeding up Neural Radiance Fields with Thousands of Tiny MLPs Check out the paper on arXiv: https://arxiv.org/abs/2103.13744 This repo cont

Christian Reiser 373 Dec 20, 2022
This repository contains a PyTorch implementation of "AD-NeRF: Audio Driven Neural Radiance Fields for Talking Head Synthesis".

AD-NeRF: Audio Driven Neural Radiance Fields for Talking Head Synthesis | Project Page | Paper | PyTorch implementation for the paper "AD-NeRF: Audio

551 Dec 29, 2022
Code release for DS-NeRF (Depth-supervised Neural Radiance Fields)

Depth-supervised NeRF: Fewer Views and Faster Training for Free Project | Paper | YouTube Pytorch implementation of our method for learning neural rad

524 Jan 8, 2023
PyTorch implementation for MINE: Continuous-Depth MPI with Neural Radiance Fields

MINE: Continuous-Depth MPI with Neural Radiance Fields Project Page | Video PyTorch implementation for our ICCV 2021 paper. MINE: Towards Continuous D

Zijian Feng 325 Dec 29, 2022
BARF: Bundle-Adjusting Neural Radiance Fields 🤮 (ICCV 2021 oral)

BARF 🤮: Bundle-Adjusting Neural Radiance Fields Chen-Hsuan Lin, Wei-Chiu Ma, Antonio Torralba, and Simon Lucey IEEE International Conference on Comp

Chen-Hsuan Lin 539 Dec 28, 2022
[ICCV21] Self-Calibrating Neural Radiance Fields

Self-Calibrating Neural Radiance Fields, ICCV, 2021 Project Page | Paper | Video Author Information Yoonwoo Jeong [Google Scholar] Seokjun Ahn [Google

381 Dec 30, 2022
[ICCV 2021 Oral] NerfingMVS: Guided Optimization of Neural Radiance Fields for Indoor Multi-view Stereo

NerfingMVS Project Page | Paper | Video | Data NerfingMVS: Guided Optimization of Neural Radiance Fields for Indoor Multi-view Stereo Yi Wei, Shaohui

Yi Wei 369 Dec 24, 2022
This is the code for "HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields".

HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields This is the code for "HyperNeRF: A Higher-Dimensional

Google 702 Jan 2, 2023
A PyTorch implementation of NeRF (Neural Radiance Fields) that reproduces the results.

NeRF-pytorch NeRF (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are

Yen-Chen Lin 3.2k Jan 8, 2023
pixelNeRF: Neural Radiance Fields from One or Few Images

pixelNeRF: Neural Radiance Fields from One or Few Images Alex Yu, Vickie Ye, Matthew Tancik, Angjoo Kanazawa UC Berkeley arXiv: http://arxiv.org/abs/2

Alex Yu 1k Jan 4, 2023
D-NeRF: Neural Radiance Fields for Dynamic Scenes

D-NeRF: Neural Radiance Fields for Dynamic Scenes [Project] [Paper] D-NeRF is a method for synthesizing novel views, at an arbitrary point in time, of

Albert Pumarola 291 Jan 2, 2023
Code release for NeRF (Neural Radiance Fields)

NeRF: Neural Radiance Fields Project Page | Video | Paper | Data Tensorflow implementation of optimizing a neural representation for a single scene an

6.5k Jan 1, 2023
A PyTorch re-implementation of Neural Radiance Fields

nerf-pytorch A PyTorch re-implementation Project | Video | Paper NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis Ben Mildenhall

Krishna Murthy 709 Jan 9, 2023
[ICCV'21] UNISURF: Unifying Neural Implicit Surfaces and Radiance Fields for Multi-View Reconstruction

UNISURF: Unifying Neural Implicit Surfaces and Radiance Fields for Multi-View Reconstruction Project Page | Paper | Supplementary | Video This reposit

331 Dec 28, 2022