v objective diffusion inference code for JAX.

Overview

v-diffusion-jax

v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman).

The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI).

Dependencies

  • JAX (installation instructions)

  • dm-haiku, einops, numpy, optax, Pillow, tqdm (install with pip install)

  • CLIP_JAX (https://github.com/kingoflolz/CLIP_JAX), and its additional pip-installable dependencies: ftfy, regex, torch, torchvision (it does not need GPU PyTorch). If you git clone --recursive this repo, it should fetch CLIP_JAX automatically.

Model checkpoints:

  • Danbooru SFW 128x128, SHA-256 8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334

  • ImageNet 128x128, SHA-256 4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530

  • WikiArt 128x128, SHA-256 8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a

  • WikiArt 256x256, SHA-256 ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81

Sampling

Example

If the model checkpoints are stored in checkpoints/, the following will generate an image:

./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --seed 0

If they are somewhere else, you need to specify the path to the checkpoint with --checkpoint.

Unconditional sampling

usage: sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--eta ETA] --model
                 {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED]
                 [--steps STEPS]

--batch-size: sample this many images at a time (default 1)

--checkpoint: manually specify the model checkpoint file

--eta: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps.

--model: specify the model to use

-n: sample until this many images are sampled (default 1)

--seed: specify the random seed (default 0)

--steps: specify the number of diffusion timesteps (default is 1000, can lower for faster but lower quality sampling)

CLIP guided sampling

CLIP guided sampling lets you generate images with diffusion models conditional on the output matching a text prompt.

usage: clip_sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT]
                      [--clip-guidance-scale CLIP_GUIDANCE_SCALE] [--eta ETA] --model
                      {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED]
                      [--steps STEPS]
                      prompt

clip_sample.py has the same options as sample.py and these additional ones:

prompt: the text prompt to use

--clip-guidance-scale: how strongly the result should match the text prompt (default 1000)

You might also like...
Exact Pareto Optimal solutions for preference based Multi-Objective Optimization

Exact Pareto Optimal solutions for preference based Multi-Objective Optimization

Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Official implementation of
Official implementation of "A Unified Objective for Novel Class Discovery", ICCV2021 (Oral)

A Unified Objective for Novel Class Discovery This is the official repository for the paper: A Unified Objective for Novel Class Discovery Enrico Fini

Official implementation of NeurIPS 2021 paper "One Loss for All: Deep Hashing with a Single Cosine Similarity based Learning Objective"

Official implementation of NeurIPS 2021 paper "One Loss for All: Deep Hashing with a Single Cosine Similarity based Learning Objective"

Information-Theoretic Multi-Objective Bayesian Optimization with Continuous Approximations

Information-Theoretic Multi-Objective Bayesian Optimization with Continuous Approximations Requirements The code is implemented in Python and requires

Pytorch implementation of "MOSNet: Deep Learning based Objective Assessment for Voice Conversion"

MOSNet pytorch implementation of "MOSNet: Deep Learning based Objective Assessment for Voice Conversion" https://arxiv.org/abs/1904.08352 Dependency L

Multi-objective gym environments for reinforcement learning.
Multi-objective gym environments for reinforcement learning.

MO-Gym: Multi-Objective Reinforcement Learning Environments Gym environments for multi-objective reinforcement learning (MORL). The environments follo

Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capability)
Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capability)

Protein GLM (wip) Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capabil

Paaster is a secure by default end-to-end encrypted pastebin built with the objective of simplicity.
Paaster is a secure by default end-to-end encrypted pastebin built with the objective of simplicity.

Follow the development of our desktop client here Paaster Paaster is a secure by default end-to-end encrypted pastebin built with the objective of sim

Comments
  • Getting error on clip_sample

    Getting error on clip_sample

    Targeting a colab notebook for your code. First tried this colab that begins with:

    "https://github.com/crowsonkb/v-diffusion-jax Original Colab from BoneAmputee, edited by danielrussruss to allow for batch_size > 1, displaying and saving progress. Update: Fixed a problem blocking ability to change prompt without re-running some additional cells."

    and got this error: "RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: UNKNOWN: GetConvolveAlgorithms failed."

    Then, I went to your repository (this one), and just hand-crafted a colab notebook, but then it fails with the same error when I try to run clip_sample.

    Any ideas? If you know of a working colab notebook, let me know. Running a v100

    opened by metaphorz 3
  • explaining the continuous time ddim sampling

    explaining the continuous time ddim sampling

    Hi! I am looking for a code for DDIM sampling when the noise scheduler is continuous, and I came to this repository. I have a question on https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/sampling.py#L11, since the model output is v, are you suggesting that the parametrization of the model is that the Unet output the signal to noise ratio, rather than eps? Thanks!

    opened by XavierXiao 0
  • enable cog

    enable cog

    Hey @crowsonkb ! 👋

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can try out your model! View it here: https://replicate.ai/crowsonkb/v-diffusion-jax

    Do claim your page here so you can own the page, customise the Example gallery as you like, and we'll feature it on our website and tweet about it too.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by chenxwh 0
Owner
Katherine Crowson
AI/generative artist.
Katherine Crowson
Minimal diffusion models - Minimal code and simple experiments to play with Denoising Diffusion Probabilistic Models (DDPMs)

Minimal code and simple experiments to play with Denoising Diffusion Probabilist

Rithesh Kumar 16 Oct 6, 2022
Pytorch-diffusion - A basic PyTorch implementation of 'Denoising Diffusion Probabilistic Models'

PyTorch implementation of 'Denoising Diffusion Probabilistic Models' This reposi

Arthur Juliani 76 Jan 7, 2023
GAN JAX - A toy project to generate images from GANs with JAX

GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developped by Google and

Valentin Goldité 14 Nov 29, 2022
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 3, 2022
Torchserve server using a YoloV5 model running on docker with GPU and static batch inference to perform production ready inference.

Yolov5 running on TorchServe (GPU compatible) ! This is a dockerfile to run TorchServe for Yolo v5 object detection model. (TorchServe (PyTorch librar

null 82 Nov 29, 2022
Monocular 3D pose estimation. OpenVINO. CPU inference or iGPU (OpenCL) inference.

human-pose-estimation-3d-python-cpp RealSenseD435 (RGB) 480x640 + CPU Corei9 45 FPS (Depth is not used) 1. Run 1-1. RealSenseD435 (RGB) 480x640 + CPU

Katsuya Hyodo 8 Oct 3, 2022
PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices.

PyTorch-LIT PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices. With

Amin Rezaei 157 Dec 11, 2022
Data-depth-inference - Data depth inference with python

Welcome! This readme will guide you through the use of the code in this reposito

Marco 3 Feb 8, 2022
Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).

Relation Prediction as an Auxiliary Training Objective for Knowledge Base Completion This repo provides the code for the paper Relation Prediction as

Facebook Research 85 Jan 2, 2023
[ICML 2020] Prediction-Guided Multi-Objective Reinforcement Learning for Continuous Robot Control

PG-MORL This repository contains the implementation for the paper Prediction-Guided Multi-Objective Reinforcement Learning for Continuous Robot Contro

MIT Graphics Group 65 Jan 7, 2023