CoulombGas

Finite-temperature variational Monte Carlo calculation of the uniform electron gas using neural canonical transformations.

Overview

This code implements the neural canonical transformation approach to the thermodynamic properties of the uniform electron gas. Built on JAX, it uses both forward- and backward-mode automatic differentiation together with the pmap mechanism to achieve large-scale single-program multiple-data (SPMD) training on multiple GPUs.
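
To give a flavor of the SPMD pattern, here is a minimal pmap sketch with a toy loss and update rule; it is not the repository's actual training step:

from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Toy quadratic loss standing in for the variational free energy.
    return jnp.mean((x @ params) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, x):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    # Average gradients over all devices so the replicas stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads, loss

devices = jax.local_devices()
params = jax.device_put_replicated(jnp.ones(3), devices)
x = jnp.ones((len(devices), 128, 3))   # leading axis enumerates devices
params, loss = train_step(params, x)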

Requirements

  • JAX with Nvidia GPU support
  • A handful of GPUs. The more the better :P
  • haiku
  • optax
  • mpmath, for analytically computing the thermal entropy of a non-interacting Fermi gas in the canonical ensemble using arbitrary-precision arithmetic (see the sketch after this list).
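
For reference, the canonical partition function of N non-interacting fermions obeys the exact recursion Z_N = (1/N) * sum_{n=1..N} (-1)^(n+1) z(n*beta) Z_{N-n} with the single-particle partition function z(beta) = sum_k exp(-beta*eps_k). The alternating sum suffers severe cancellation at low temperature, which is where arbitrary precision pays off. A self-contained sketch with a toy spectrum (not the repository's actual routine):

from mpmath import mp, mpf, exp, log, diff

mp.dps = 50                              # 50 decimal digits tame the cancellations

eps = [mpf(k * k) for k in range(30)]    # toy single-particle levels
N = 7                                    # particle number

def lnZ(beta):
    # Fermionic canonical recursion: Z_M = (1/M) sum_n (-1)^(n+1) z(n*beta) Z_{M-n}.
    z = lambda n: sum(exp(-n * beta * e) for e in eps)
    Zs = [mpf(1)]                        # Z_0 = 1
    for M in range(1, N + 1):
        Zs.append(sum((-1) ** (n + 1) * z(n) * Zs[M - n]
                      for n in range(1, M + 1)) / M)
    return log(Zs[N])

beta = mpf("1.0")
E = -diff(lnZ, beta)                     # E = -d lnZ / d beta
S = beta * E + lnZ(beta)                 # S = beta*(E - F), with F = -lnZ/beta
print(S)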

Demo run

To start, try running the following commands to launch a training of 13 spin-polarized electrons in 2D, with dimensionless density parameter rs = 10.0 and reduced temperature Theta = 0.15, on 8 GPUs:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python main.py --n 13 --dim 2 --rs 10.0 --Theta 0.15 --Emax 25 --sr --batch 4096 --num_devices 8 --acc_steps 2

Note that each training step effectively samples a total batch of 8192 samples. Such a batch size, however, consumes more memory than 8 GPUs can accommodate. To overcome this, we split the batch into two equal pieces and accumulate the gradient and various observables over two sequential substeps. In other words, the argument batch in the command above actually stands for the batch size per accumulation step, and the effective batch size is batch x acc_steps = 4096 x 2 = 8192.

If you have only, say, 4 GPUs, you can set batch, num_devices, and acc_steps to 2048, 4, and 4, respectively, to launch the same training process, at the expense of doubling the running time. The total GPU hours are nevertheless the same.
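
Schematically, the accumulation works as follows. This is a sketch with a toy gradient, not the repository's actual update, which also accumulates observables and supports stochastic reconfiguration:

import jax
import jax.numpy as jnp

def grad_fn(params, sub_batch):
    # Gradient of a toy loss on one sub-batch of size `batch`.
    return jax.grad(lambda p: jnp.mean((sub_batch @ p) ** 2))(params)

def accumulated_grad(params, sub_batches, acc_steps):
    # Process acc_steps sub-batches sequentially, so the effective batch
    # is acc_steps * batch without it ever residing in memory at once.
    total = jax.tree_util.tree_map(jnp.zeros_like, params)
    for i in range(acc_steps):
        g = grad_fn(params, sub_batches[i])
        total = jax.tree_util.tree_map(jnp.add, total, g)
    return jax.tree_util.tree_map(lambda g: g / acc_steps, total)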

For the detailed meaning of the other command-line arguments, run

python main.py --help

or directly refer to the source code.

Trained model and data

A training process from scratch actually consists of two stages. In the first stage, a variational autoregressive network is pretrained to approximate the Boltzmann distribution of the corresponding non-interacting electron gas. The resulting model can be saved and loaded later. For convenience, we provide such a model file for the parameter settings of the last section, so you can quickly get a feel for the second stage, which trains the truly interacting system of interest. We encourage you to remove the file and pretrain the model yourself; it is actually much faster than the second-stage training.
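
Schematically, the pretraining minimizes the variational free energy F = E_q[log q(s)/beta + E_0(s)] over the autoregressive distribution q, using a REINFORCE-style surrogate loss. The sketch below uses assumed function names; sample, log_prob, and energies are not the repository's actual API:

import jax
import jax.numpy as jnp

def pretrain_loss(params, key, sample, log_prob, energies, beta, batch):
    # sample: draws `batch` many-body state indices s ~ q(s|params).
    # log_prob: evaluates log q(s|params); energies[s] is the
    # non-interacting many-body energy of state s.
    s = sample(params, key, batch)
    logq = log_prob(params, s)
    f = logq / beta + energies[s]              # per-sample free energy
    # REINFORCE with the batch mean as a variance-reducing baseline:
    # grad of E_q[f] equals E_q[(f - baseline) * grad log q].
    advantage = jax.lax.stop_gradient(f - f.mean())
    return jnp.mean(advantage * logq)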

To facilitate further development, we also provide the trained models and logged data for the various calculations in the paper, located in the data directory.

To cite

arXiv

Comments
  • Unnecessary memory overhead when replicating fisher information matrix

    These two lines replicate the (classical and quantum) fisher information matrices across devices: https://github.com/fermiflow/CoulombGas/blob/master/main.py#L312 https://github.com/fermiflow/CoulombGas/blob/master/main.py#L314

    However, this may temporarily incur num_devices*(nparams**2) storage on one device in this line: https://github.com/fermiflow/CoulombGas/blob/master/utils.py#L7

    This will strongly limit the number of parameters one can handle.

    Possible solutions:

    1. find a way to do the replication without the num_devices memory overhead,
    2. or solve the SR equation on one device, then broadcast the update to all devices (see the sketch below).
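
    A sketch of option 2 (hypothetical names and shapes): solve the damped SR system on a single device and replicate only the O(nparams) update, never the O(nparams**2) matrix:

    import jax
    import jax.numpy as jnp

    def sr_update(fim, grad, damping=1e-3):
        # fim: (nparams, nparams) on a single device; grad: (nparams,).
        dtheta = jnp.linalg.solve(fim + damping * jnp.eye(fim.shape[0]), grad)
        # Broadcast only the nparams-sized update to all devices.
        return jax.device_put_replicated(dtheta, jax.local_devices())
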
    opened by wangleiphy 2
  • CI failing due to deprecated `jax.ops.index_update`

    Currently CI is failing: https://github.com/fermiflow/CoulombGas/runs/5611973129?check_suite_focus=true

    This is probably due to a change in jax 0.3.2; from the changelog (https://jax.readthedocs.io/en/latest/changelog.html):

    The functions jax.ops.index_update, jax.ops.index_add, which were deprecated in 0.2.22, have been removed. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y).
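
    The migration is mechanical; for example:

    import jax.numpy as jnp

    x = jnp.zeros(5)
    # Before (removed in jax 0.3.2):
    #   x = jax.ops.index_update(x, 2, 1.0)
    #   x = jax.ops.index_add(x, 3, 1.0)
    # After:
    x = x.at[2].set(1.0)
    x = x.at[3].add(1.0)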

    Please fix.

    opened by wangleiphy 1
  • Adam optimizer errors

    Looks like there is a problem with parameter pmap when using the Adam optimizer.

    This command (note I have dropped --sr)

    export CUDA_VISIBLE_DEVICES=0,1,2,3
    python main.py --n 13 --dim 2 --rs 10.0 --Theta 0.15 --Emax 25 --batch 4096 --num_devices 4 --acc_steps 2
    

    gives

    ...
    iter: 0001 F: -1.9656119959069707 F_std: 0.002970439785129169 E: -1.9296861940063403 E_std: 0.0029793946046783874 K: 0.2892473978205551 K_std: 0.00013518503689654573 V: -2.2189335918268953 V_std: 0.002970435066943954 S: 5.98763365010506 S_std: 0.022527996149788027 accept_rate: 0.6797216796875
    Traceback (most recent call last):
      File "/home/wanglei/CoulombGas/main.py", line 322, in <module>
        keys, state_indices, x, accept_rate = sample_stateindices_and_x(keys,
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2058, in cache_miss
        out_tree, out_flat = f_pmapped_(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/api.py", line 1934, in f_pmapped
        out = pxla.xla_pmap(
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 1727, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 1652, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 1730, in process
        return trace.process_map(self, fun, tracers, params)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 633, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 766, in xla_pmap_impl
        compiled_fun, fingerprint = parallel_callable(
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
        ans = call(fun, *args)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 794, in parallel_callable
        pmap_computation = lower_parallel_callable(
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 964, in lower_parallel_callable
        jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 871, in stage_parallel_callable
        jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1566, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1543, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/wanglei/CoulombGas/VMC.py", line 22, in sample_stateindices_and_x
        state_indices = sampler(params_van, key_state, batch)
      File "/home/wanglei/CoulombGas/sampler.py", line 37, in sampler
        logits = jax.vmap(_logits, (None, 0), 0)(params, state_indices)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/api.py", line 1520, in batched_fun
        out_flat = batching.batch(
      File "/home/wanglei/.local/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/wanglei/CoulombGas/sampler.py", line 27, in _logits
        logits = network.apply(params, None, sp_indices[state_idx])
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
        out, state = f.apply(params, {}, *args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 400, in apply_fn
        out = f(*args, **kwargs)
      File "/home/wanglei/CoulombGas/main.py", line 94, in forward_fn
        return model(state_idx)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/wanglei/CoulombGas/autoregressive.py", line 73, in __call__
        x = hk.Linear(self.model_size,
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/basic.py", line 174, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/base.py", line 331, in get_parameter
        raise ValueError(
    jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'transformer/embedding_mlp/w' with retrieved shape (4, 4, 2, 16) does not match shape=[2, 16] dtype=dtype('int64')
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/wanglei/CoulombGas/main.py", line 322, in <module>
        keys, state_indices, x, accept_rate = sample_stateindices_and_x(keys,
      File "/home/wanglei/CoulombGas/VMC.py", line 22, in sample_stateindices_and_x
        state_indices = sampler(params_van, key_state, batch)
      File "/home/wanglei/CoulombGas/sampler.py", line 37, in sampler
        logits = jax.vmap(_logits, (None, 0), 0)(params, state_indices)
      File "/home/wanglei/CoulombGas/sampler.py", line 27, in _logits
        logits = network.apply(params, None, sp_indices[state_idx])
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
        out, state = f.apply(params, {}, *args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 400, in apply_fn
        out = f(*args, **kwargs)
      File "/home/wanglei/CoulombGas/main.py", line 94, in forward_fn
        return model(state_idx)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/wanglei/CoulombGas/autoregressive.py", line 73, in __call__
        x = hk.Linear(self.model_size,
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/basic.py", line 174, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/base.py", line 331, in get_parameter
        raise ValueError(
    ValueError: 'transformer/embedding_mlp/w' with retrieved shape (4, 4, 2, 16) does not match shape=[2, 16] dtype=dtype('int64')
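
    For what it's worth, the retrieved shape (4, 4, 2, 16) against the expected [2, 16] with num_devices = 4 looks consistent with the parameter pytree having been replicated twice across the 4 devices before entering the pmapped sampler. A hypothetical one-line check (params_van is the variable from the traceback above):

    import jax

    # Each leaf should carry exactly one leading device axis of size
    # jax.local_device_count(); a (4, 4, ...) leading shape means the
    # pytree went through the replication step twice.
    print(jax.tree_util.tree_map(lambda a: a.shape, params_van))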
    
    opened by wangleiphy 0
Owner

FermiFlow: ab-initio study of fermions at finite temperature