I tested the code on a Titan V GPU (12 GB of GPU memory) in a machine with 128 GB of system RAM, and it runs out of GPU memory.
The paper reports per-scene optimization on a single GPU; how much GPU memory does the optimization require?
Traceback (most recent call last):
File "plenoptimize.py", line 479, in <module>
main()
File "plenoptimize.py", line 434, in main
mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
File "plenoptimize.py", line 434, in <lambda>
mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
File "plenoptimize.py", line 293, in get_loss_rays
rgb, disp, acc, weights, voxel_ids = plenoxel.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
jax._src.traceback_util.FilteredStackTrace: RuntimeError: Resource exhausted: Out of memory while trying to allocate 196608000 bytes.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "plenoptimize.py", line 479, in <module>
main()
File "plenoptimize.py", line 434, in main
mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/api.py", line 809, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1878, in _vjp
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 498, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "plenoptimize.py", line 434, in <lambda>
mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays((indices, grid), batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data)
File "plenoptimize.py", line 293, in get_loss_rays
rgb, disp, acc, weights, voxel_ids = plenoxel.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss
out_flat = xla.xla_call(
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
return call_bind(self, fun, *args, **params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1405, in process
return trace.process_call(self, fun, tracers, params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 308, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
return call_bind(self, fun, *args, **params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1405, in process
return trace.process_call(self, fun, tracers, params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 191, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 299, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
return call_bind(self, fun, *args, **params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1405, in process
return trace.process_call(self, fun, tracers, params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 579, in _xla_call_impl
return compiled_fun(*args)
File "/home/hdr/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 830, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: Resource exhausted: Out of memory while trying to allocate 196608000 bytes.