QDax: Accelerated Quality-Diversity
QDax is a tool to accelerate Quality-Diversity (QD) algorithms through hardware accelerators and massive parallelism.
QDax paper: https://arxiv.org/abs/2202.01258
Installation
Dependencies
In particular, QDax relies on the JAX and brax libraries. To install all dependencies, you can run the following command:
pip install -r requirements.txt
Installing QDax
pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git
Examples
There are two ways to run QDax:
- Colab notebooks (visualization included) - recommended, as it also avoids downloading dependencies and configuring the environment. Open the notebook in the notebooks directory and run it according to the walkthrough instructions.
- Locally - A singularity folder is provided to easily install everything in a container. If you use the singularity image or install the dependencies locally, you can run a single experiment with, for example:
python run_qd.py --env_name walker --grid_shape 30 30 --batch_size 2048 --num-evaluations 1000000
Alternatively, to run experiments that compare the effect of batch sizes, use the command below. For example, to run the experiments on the walker environment (which has a 2-dimensional BD) with a grid shape of (30,30) and 5 replications:
python3 run_comparison_batch_sizes.py --env_name walker --grid_shape 30 30 -n 5
To restrict the run to one or more specific GPUs, set CUDA_VISIBLE_DEVICES:
CUDA_VISIBLE_DEVICES=0 python3 run_comparison_batch_sizes.py --env_name walker --grid_shape 30 30 -n 5
CUDA_VISIBLE_DEVICES="0,1" python3 run_comparison_batch_sizes.py --env_name walker --grid_shape 30 30 -n 5
Analysis and Plotting Tools
python3 analysis/plot_metrics.py --exp-name qdax_training --results ./qdax_walker_fixednumevals/ --attribute population_size --save figure.png
where:
- `--exp-name` is the name of the experiment directories (it will look for directories that start with that string).
- `--results` is the directory containing all the results folders.
- `--attribute` is the attribute on which to compare the results.
Code Structure (for developers)
One thing to note beforehand is that JAX relies on a functional programming paradigm. We try as much as possible to maintain this programming style.
The main file used is `qdax/training/qd.py`. This file contains the main `train` function, which consists of the entire QD loop and supporting functions.
- Inputs: the `train` function takes as input the task, the emitter and the hyperparameters.
- Functions: the main functions used by `train` are also declared in this file, listed in top-down order of importance. The key function here is `_es_one_epoch`. In terms of QD, it determines the loop performed at each generation: (1) selection (from the archive) and variation to generate solutions to be evaluated, defined by the `emitter_fn`; (2) evaluation; and (3) archive update, defined by `eval_and_add_fn`. The first part of the `train` function is `init_phase_fn`, which initializes the archive using random policies.
- Flow: `train` first calls `init_phase_fn` and then `_es_one_epoch` for a defined number of generations or evaluations.
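The flow above can be sketched as a self-contained toy. The function names (`train`, `init_phase_fn`, `_es_one_epoch`, `emitter_fn`, `eval_and_add_fn`) mirror the text, but every body below is a simplified stand-in, and all signatures are assumptions - the real implementations in `qdax/training/qd.py` differ:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class TrainingState(NamedTuple):
    key: jax.Array
    archive: jax.Array    # fitness per cell; -inf marks an empty cell
    solutions: jax.Array  # one stored solution per cell

def eval_and_add_fn(key_es_eval, archive, solutions, candidates):
    # (2) Evaluation and (3) archive update: keep a candidate if it beats
    # the fitness stored in its cell (toy: cell = row index, fitness = -sum|x|).
    # Evaluation is deterministic here, so key_es_eval goes unused; in the
    # real code it would seed stochastic rollouts.
    del key_es_eval
    fitness = -jnp.abs(candidates).sum(axis=1)
    better = fitness > archive
    archive = jnp.where(better, fitness, archive)
    solutions = jnp.where(better[:, None], candidates, solutions)
    return archive, solutions

def emitter_fn(key, solutions):
    # (1) Selection and variation: perturb the stored solutions.
    return solutions + 0.1 * jax.random.normal(key, solutions.shape)

def init_phase_fn(key, num_cells=8, dim=2):
    # Initialize the archive by evaluating random policies.
    key, key_init, key_eval = jax.random.split(key, 3)
    candidates = jax.random.normal(key_init, (num_cells, dim))
    archive = jnp.full((num_cells,), -jnp.inf)
    solutions = jnp.zeros((num_cells, dim))
    archive, solutions = eval_and_add_fn(key_eval, archive, solutions, candidates)
    return TrainingState(key, archive, solutions)

def _es_one_epoch(training_state):
    key, key_emitter, key_es_eval = jax.random.split(training_state.key, 3)
    candidates = emitter_fn(key_emitter, training_state.solutions)
    archive, solutions = eval_and_add_fn(
        key_es_eval, training_state.archive, training_state.solutions, candidates)
    return TrainingState(key, archive, solutions)

def train(seed=0, num_epochs=5):
    state = init_phase_fn(jax.random.PRNGKey(seed))
    for _ in range(num_epochs):
        state = _es_one_epoch(state)
    return state
```

Because the archive update only ever replaces a cell with a strictly better candidate, the per-cell fitness is monotonically non-decreasing over epochs, which is the invariant the real QD loop maintains as well.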
Notes
Key Management
key = jax.random.PRNGKey(seed)
key, key_model, key_env = jax.random.split(key, 3)
- `key` is for `training_state.key`
- `key_model` is for `policy_model.init`
- `key_env` is for environment initialisations (although in our deterministic case we do not really use this)
From the flow of the program, we perform an `init_phase` first. The `init_phase` function uses the `training_state.key` and outputs the updated `training_state` (with a new key) after performing the initialization (initializing the archive by evaluating random policies).
After this, we rely on the `training_state.key` in `es_one_epoch` being managed. In `es_one_epoch(training_state)`:
key, key_emitter, key_es_eval = jax.random.split(training_state.key, 3)
- `key_selection` is passed into the selection function
- `key_petr` is passed into the mutation function (iso_dd)
- `key_es_eval` is passed into `eval_and_add`
- `key` is saved as the new `training_state.key` for the next epoch, and the `training_state` is returned as an output of this function.
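The key-management pattern described above can be distilled into a minimal sketch: each epoch splits the carried key into one subkey per source of randomness and stores the fresh `key` back for the next epoch. The subkey names follow the split line above; what each subkey is consumed by is simplified here (real QDax passes them into the emitter and evaluation functions):

```python
import jax

def one_epoch(key):
    # Split the carried key: `key` is stored back for the next epoch,
    # the subkeys each feed one source of randomness this epoch.
    key, key_emitter, key_es_eval = jax.random.split(key, 3)
    # key_emitter would drive selection/variation (e.g. the iso_dd
    # mutation); key_es_eval would drive evaluation. Here we just
    # draw samples to stand in for those uses.
    mutation_noise = jax.random.normal(key_emitter, (2,))
    eval_noise = jax.random.normal(key_es_eval, (2,))
    return key, mutation_noise, eval_noise

key = jax.random.PRNGKey(0)
for _ in range(3):
    key, mutation_noise, eval_noise = one_epoch(key)
```

Because the carried key is re-split every epoch, each epoch sees fresh randomness, while the entire run stays reproducible from the single initial seed.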
Contributors
QDax is currently developed and maintained by the Adaptive & Intelligent Robotics Lab (AIRL):