Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis Implementation

Overview

Open in Streamlit Open In Colab

This project is an attempt to implement the paper Putting NeRF on a Diet (DietNeRF) in JAX/Flax. DietNeRF is designed to render high-quality novel views in a few-shot setting, a task that vanilla NeRF (Neural Radiance Field) struggles with. To achieve this, the authors introduce a Semantic Consistency Loss that supervises DietNeRF with prior knowledge from a CLIP Vision Transformer. This supervision enables DietNeRF to learn 3D scene reconstruction using CLIP's prior knowledge of 2D views.

Besides this repo, you can check our write-up and demo here:

🤩 Demo

  1. You can check out our demo on Hugging Face Spaces
  2. Or you can set up our Streamlit demo locally (model checkpoints will be fetched automatically upon startup)
pip install -r requirements_demo.txt
streamlit run app.py

Streamlit Demo

Implementation

Our code is written in JAX/Flax and is mainly based on jaxnerf from Google Research. The base code is highly optimized for GPUs and TPUs. For the semantic consistency loss, we utilize the pretrained CLIP Vision Transformer from the transformers library.
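
As a quick illustration, the semantic consistency loss boils down to embedding a rendered view and an observed training view with the pretrained CLIP vision tower and penalizing their cosine dissimilarity. The sketch below is a minimal illustration under our own assumptions (checkpoint name, resizing, normalization), not the repo's actual code; see the repo's clip_utils.py for the real implementation.

import jax.numpy as jnp
from jax import image as jax_image
from transformers import FlaxCLIPModel

# ViT-B/32 checkpoint used here as an example; the repo may use a different one.
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def clip_embedding(rgb):
    # rgb: [H, W, 3] image with values in [0, 1]
    x = jax_image.resize(rgb, (224, 224, 3), method="bilinear")
    x = (x - 0.5) / 0.5                      # rough normalization (assumption)
    x = jnp.transpose(x, (2, 0, 1))[None]    # [1, 3, 224, 224], as CLIP expects
    feats = clip.get_image_features(pixel_values=x)
    return feats / jnp.linalg.norm(feats, axis=-1, keepdims=True)

def semantic_consistency_loss(rendered_view, observed_view):
    # 1 - cosine similarity between the CLIP embeddings of the two views
    return 1.0 - jnp.sum(clip_embedding(rendered_view) * clip_embedding(observed_view))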

To learn more about DietNeRF and about our experiments and implementation, we highly recommend checking out our very detailed Notion write-up!

🤗 Hugging Face Model Hub Repo

You can also find our project and our model checkpoints in our Hugging Face Model Hub repository. The model checkpoints are located in the models folder.

Our JAX/Flax implementation currently supports:

Platform   | Single-Host GPU              | Multi-Device TPU
Type       | Single-Device | Multi-Device | Single-Host | Multi-Host
Training   | Supported     | Supported    | Supported   | Supported
Evaluation | Supported     | Supported    | Supported   | Supported

💻 Installation

# Clone the repo
git clone https://github.com/codestella/putting-nerf-on-a-diet
# Create a conda environment. Note: use Python 3.6-3.8, since one of the
# dependencies (TensorFlow) does not support Python 3.9 yet.
conda create --name jaxnerf python=3.6.12; conda activate jaxnerf
# Prepare pip
conda install pip; pip install --upgrade pip
# Install requirements
pip install -r requirements.txt
# [Optional] Install GPU and TPU support for Jax
# Remember to change cuda110 to match your CUDA version, e.g. cuda111 for CUDA 11.1.
pip install --upgrade jax "jax[cuda110]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Install Flax and the Flax extras of transformers
pip install flax transformers[flax]

Dataset

Download the datasets from the NeRF official Google Drive. Please download nerf_synthetic.zip and unzip it wherever you like. In the following, we assume the data is placed under /tmp/jaxnerf/data/.

🤟 How to Train

  1. Train in our prepared Colab notebook: Colab Pro is recommended, otherwise you may encounter out-of-memory errors.
  2. Train locally: set use_semantic_loss=true in your yaml configuration file to enable DietNeRF (a rough sketch of how this term enters the training loss follows the command below).
python -m train \
  --data_dir=/PATH/TO/YOUR/SCENE/DATA \ # (e.g. nerf_synthetic/lego)
  --train_dir=/PATH/TO/THE/PLACE/YOU/WANT/TO/SAVE/CHECKPOINTS \
  --config=configs/CONFIG_YOU_LIKE
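
The sketch below is an illustration only, not the repo's actual training code: it shows one way the semantic consistency term could be combined with NeRF's usual photometric loss once use_semantic_loss is enabled. The weight and the apply-every-K-steps schedule are hypothetical placeholders; check the files in configs/ and the Notion write-up for the real settings.

import jax.numpy as jnp

def training_loss(rendered_rgb, target_rgb, rendered_view, observed_view, step,
                  use_semantic_loss=True, sc_every=10, sc_weight=0.1):
    # Standard NeRF photometric (MSE) loss on the sampled rays.
    loss = jnp.mean((rendered_rgb - target_rgb) ** 2)
    # Add the semantic consistency term every sc_every steps (step is a Python int here).
    # semantic_consistency_loss refers to the sketch in the Implementation section above.
    if use_semantic_loss and step % sc_every == 0:
        loss = loss + sc_weight * semantic_consistency_loss(rendered_view, observed_view)
    return loss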

💎 Experimental Results

Rendering results from DietNeRF trained with only 8 views (8-shot)

DietNeRF has a strong capacity to generalize to novel and challenging views with EXTREMELY FEW TRAINING SAMPLES!

HOTDOG / DRUM / SHIP / CHAIR / LEGO / MIC

Rendered GIFs from NeRF and DietNeRF trained on occluded 14-shot data

We added artificial occlusion to the right side of the images (only left-side training poses were picked). This experiment compares reconstruction quality under occlusion: DietNeRF shows better quality than the original NeRF when the training views are occluded.
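
For illustration only: if the occlusion is implemented by masking out the right half of each training image, it could look roughly like the snippet below. This is an assumption for clarity, not the exact code used in our experiment.

import numpy as np

def occlude_right_half(images):
    # images: [N, H, W, 3] array of training views; returns a copy with the right half blacked out.
    occluded = np.array(images, copy=True)
    width = occluded.shape[2]
    occluded[:, :, width // 2:, :] = 0.0
    return occluded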

Training poses

LEGO

DietNeRF / NeRF

SHIP

DietNeRF / NeRF

👨‍👧‍👦 Our Team

Teams            | Members
Project Managing | Stella Yang (to watch our project progress, please check our project Notion)
NeRF Team        | Stella Yang, Alex Lau, Seunghyun Lee, Hyunkyu Kim, Haswanth Aekula, JaeYoung Chung
CLIP Team        | Seunghyun Lee, Sasikanth Kotti, Khalid Sifullah, Sunghyun Kim
Cloud TPU Team   | Alex Lau, Aswin Pyakurel, JaeYoung Chung, Sunghyun Kim

*Special mention to our "night owl" contributors 🦉 : Seunghyun Lee, Alex Lau, Stella Yang, Haswanth Aekula

💞 Social Impact

  • Game Industry
  • Augmented Reality Industry
  • Virtual Reality Industry
  • Graphics Industry
  • Online shopping
  • Metaverse
  • Digital Twin
  • Mapping / SLAM

🌱 References

This project is based on “JAX-NeRF”.

@software{jaxnerf2020github,
  author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan},
  title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}},
  url = {https://github.com/google-research/google-research/tree/master/jaxnerf},
  version = {0.0},
  year = {2020},
}

This project is based on “Putting NeRF on a Diet”.

@misc{jain2021putting,
      title={Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis}, 
      author={Ajay Jain and Matthew Tancik and Pieter Abbeel},
      year={2021},
      eprint={2104.00677},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

🔑 License

Apache License 2.0

❤️ Special Thanks

Our project was motivated by the HuggingFace x Google AI (JAX) Community Week Event 2021.

We would like to take this chance to thank Hugging Face for organizing such an amazing open-source initiative, and Suraj and Patrick for all the technical help. We learned a lot throughout this wonderful experience!

Finally, we would like to thank Common Computer AI for sponsoring our team's access to a multi-GPU V100 server. Thank you so much for your support!

Comments
  • How do we run on own data? with pose estimated from COLMAP

    It seems the script is only prepared to run on the Blender synthetic dataset. Could we enable it to run on our own data using COLMAP-generated poses_bounds? Thanks!

    opened by vishnukool 1
  • about the diet-pixel-nerf

    Hey! Is it possible to share the diet loss with pixel-nerf? We tried to add the diet loss to pixel-nerf but it doesn't work at all. Maybe we did something wrong?

    opened by chensjtu 1
  • [on Test] Add data_loader part in trainer.py

    [This commit is not completed]

    I added and edited the data_loader function to work in trainer.py.

    There is an unsolved problem which needs discussion.

    What JaeYoung Chung calls "preload" is reading all of the images in advance when the data is loaded. In that case, to store all images in one jax.numpy.array, we have to make every image the same shape, but I'm not sure whether that means the input (x, y, z view) has to change.

    I need some advice on this.

    Thanks

    opened by minus31 1
  • evaluation result not saved

    aswin@t1v-n-093bc110-w-0:~/hf-flax-project/putting-nerf-on-a-diet$ /usr/bin/python3 /home/aswin/hf-flax-project/putting-nerf-on-a-diet/main.py
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    1179 images
    Pose data loaded - dict_keys(['c2w_mats', 'kinv_mats', 'bds', 'res_mats'])
    Test PSNR sacre_ius_64_ilr_1_olr_0.0005_bs_64_test (min: 16.222, max: 16.222, cur: 16.222)
    Train PSNR sacre_ius_64_ilr_1_olr_0.0005_bs_64_train (min: 15.479, max: 18.230, cur: 18.230)
    0%|▎ | 500/150000 [21:30<107:11:08, 2.58s/it]
    Traceback (most recent call last):
      File "/home/aswin/hf-flax-project/putting-nerf-on-a-diet/main.py", line 84, in <module>
        my_trainer.train()
      File "/home/aswin/hf-flax-project/putting-nerf-on-a-diet/src/trainer.py", line 183, in train
        plt.savefig(os.path.join(temp_eval_result_dir, "{:06d}.png".format(step)))
      File "/home/aswin/.local/lib/python3.8/site-packages/matplotlib/pyplot.py", line 966, in savefig
        res = fig.savefig(*args, **kwargs)
      File "/home/aswin/.local/lib/python3.8/site-packages/matplotlib/figure.py", line 3005, in savefig
        self.canvas.print_figure(fname, **kwargs)
      File "/home/aswin/.local/lib/python3.8/site-packages/matplotlib/backend_bases.py", line 2255, in print_figure
        result = print_method(
      File "/home/aswin/.local/lib/python3.8/site-packages/matplotlib/backend_bases.py", line 1669, in wrapper
        return func(*args, **kwargs)
      File "/home/aswin/.local/lib/python3.8/site-packages/matplotlib/backends/backend_agg.py", line 509, in print_png
        mpl.image.imsave(
      File "/home/aswin/.local/lib/python3.8/site-packages/matplotlib/image.py", line 1616, in imsave
        image.save(fname, **pil_kwargs)
      File "/usr/local/lib/python3.8/dist-packages/PIL/Image.py", line 2169, in save
        fp = builtins.open(filename, "w+b")
    FileNotFoundError: [Errno 2] No such file or directory: 'temp/temp_eval_result_dir/sacre_ius_64_ilr_1_olr_0.0005_bs_64/000500.png'

    opened by masapasa 0
  • Add data_loader logic

    Please check the last commit: 7e87cfc

    I've checked that it is running in the "Dev notebook" on Kaggle, but I need help to verify that the training is actually running properly.

    Thank you.

    opened by minus31 0
  • Fix invalid Jax Type on single_step

    Problems: When single_step is jitted, we encountered an invalid JAX type error from flax.linen.Module and another error from N_samples.

    Solution: use the static_argnums argument of jit so that the model and N_samples arguments are treated as static, e.g.

     1. Approach 1, direct assignment: single_step = jit(single_step_wojit, static_argnums=[6, 7])
     2. Approach 2, using the jit decorator: @partial(jit, static_argnums=[6, 7])

    Right now I am using approach 1.

    Reference https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

    opened by riven314 0
  • replace jax.experimental.optix by optimisers && migrate to FLAX model

    Changes

     1. jax.experimental.optix is deprecated in the latest JAX version, so use optimisers instead. Unlike the optimizers in optix, the optimizers in optimisers require an additional step (int) argument in order to get additional state (e.g. momentum, velocity, annealed learning rate, etc.).
     2. Migrate the (simplified) NeRF model from the Haiku framework to FLAX.

    Remarks: We did notice that there are other FLAX/JAX-based implementations, but we decided not to port them into our codebase at the moment because their model implementations are far more complex and their interfaces (e.g. expected inputs) are quite different from what we have. Also, our implementation is a simplified one (refer to https://github.com/tancik/learnit/issues/9), so porting another implementation would introduce additional changes (and complexity). We can add them when we have additional bandwidth, but I don't think it should be a priority right now.

    opened by riven314 0
  • pose_spherical function in clip_utils.py

    Hi @codestella and @riven314 ,

    Thanks for the work on this repo! I just have a question about the function pose_spherical here. Specifically, it is not clear to me how you derive the C2W transformation from phi and theta. I don't understand the spherical coordinate system that you describe in the comment of the function. Also, what is the coordinate system of your NeRF? Is it the OpenGL system? Finally, what is the reasoning behind the matrix transformation: translation, followed by rot_phi, followed by rot_theta?

    Thank you!

    opened by marcelsan 1