EvoJAX is a scalable, general-purpose, hardware-accelerated neuroevolution toolkit

Overview

EvoJAX: Hardware-Accelerated Neuroevolution

EvoJAX is a scalable, general-purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, it enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPUs/GPUs. EvoJAX achieves very high performance by implementing the evolution algorithm, neural network and task all in NumPy, which is compiled just-in-time to run on accelerators.

This repo also includes several extensible examples of EvoJAX for a wide range of tasks, including supervised learning, reinforcement learning and generative art, demonstrating how EvoJAX can run your evolution experiments within minutes on a single accelerator, compared to hours or days when using CPUs.

EvoJAX paper: https://arxiv.org/abs/2202.05008

Installation

EvoJAX is implemented in JAX, which needs to be installed first.

Install JAX: Please first follow JAX's installation instructions, with optional GPU/TPU backend support. If JAX is not set up, the EvoJAX installation will still try to pull in a CPU-only version of JAX. Note that Colab runtimes come with JAX pre-installed.

Install EvoJAX:

# Install from PyPI.
pip install evojax

# Or, install from our GitHub repo.
pip install git+https://github.com/google/evojax.git@main
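
A quick sanity check that the installation worked and that JAX can see your accelerator (output varies by backend and version):

import jax
print(jax.devices())  # lists the CPU/GPU/TPU devices JAX will run on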

Code Overview

EvoJAX is a framework with three major components, which we expect users to extend.

  1. Neuroevolution Algorithms: All neuroevolution algorithms should implement the evojax.algo.base.NEAlgorithm interface and reside in evojax/algo/. We currently provide PGPE, with more coming soon.
  2. Policy Networks: All neural networks should implement the evojax.policy.base.PolicyNetwork interface and be saved in evojax/policy/. In this repo, we give example implementations of MLP, ConvNet, Seq2Seq and PermutationInvariant models.
  3. Tasks: All tasks should implement evojax.task.base.VectorizedTask and reside in evojax/task/.

These components can be used either independently or orchestrated by evojax.trainer and evojax.sim_mgr, which manage the training pipeline. While these should be sufficient for the currently provided policies and tasks, we plan to extend their functionality as the need arises.
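
As a rough end-to-end sketch of how the pieces fit together (modeled on the repo's cart-pole example; constructor argument names here are from memory and may differ between EvoJAX versions):

from evojax import Trainer
from evojax.algo import PGPE
from evojax.policy import MLPPolicy
from evojax.task.cartpole import CartPoleSwingUp

train_task = CartPoleSwingUp(test=False)   # task: vectorized environment
test_task = CartPoleSwingUp(test=True)
policy = MLPPolicy(                        # policy: maps observations to actions
    input_dim=train_task.obs_shape[0],
    hidden_dims=[64, 64],
    output_dim=train_task.act_shape[0],
)
solver = PGPE(                             # algorithm: evolves the policy parameters
    pop_size=64,
    param_size=policy.num_params,
    seed=42,
)
trainer = Trainer(policy=policy, solver=solver,
                  train_task=train_task, test_task=test_task,
                  max_iter=1000, log_dir='./log')
trainer.run()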

Examples

As a quickstart, we provide non-trivial examples (scripts in examples/ and notebooks in examples/notebooks) to illustrate the usage of EvoJAX. Example commands to start the training process are given at the top of each script. These scripts and notebooks were run with TPUs and/or NVIDIA V100 GPU(s):

Supervised Learning Tasks

While one would obviously use gradient descent for such tasks in practice, the point is to show that neuroevolution can also solve them to some degree of accuracy within a short amount of time. This becomes useful when these models are embedded in a more complicated task where gradient-based approaches may not work.

  • MNIST Classification - We show that EvoJAX trains a ConvNet policy to achieve >98% test accuracy within 5 min on a single GPU.
  • Seq2Seq Learning - We demonstrate that EvoJAX is capable of learning a large network with hundreds of thousands of parameters to accomplish a seq2seq task.

Classic Control Tasks

The purpose of including control tasks is two-fold: 1) unlike supervised learning tasks, control tasks in EvoJAX have an undetermined number of steps, so we use these examples to demonstrate the efficiency of our task roll-out loops (a sketch follows the list below); 2) we wish to show the speed-up benefit of implementing tasks in JAX and to illustrate how to implement one from scratch.

  • Locomotion - Brax is a differentiable physics engine implemented in JAX. We wrap it as a task and train with EvoJAX on GPUs/TPUs. It takes EvoJAX tens of minutes to solve a locomotion task in Brax.
  • Cart-Pole Swing Up - We illustrate how this classic control task can be implemented in JAX and integrated into EvoJAX's pipeline for significantly faster training.
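
To make point 1) concrete, here is a minimal, self-contained JAX sketch (not EvoJAX's actual sim_mgr code) of how a fixed-length jax.lax.scan can handle episodes that terminate at different times, by latching a done flag and masking rewards after termination:

import jax
import jax.numpy as jnp

def rollout(step_fn, init_state, max_steps):
    """Accumulate rewards with step_fn(state) -> (next_state, reward, done)."""
    def body(carry, _):
        state, done, total = carry
        next_state, reward, step_done = step_fn(state)
        total = total + reward * (1.0 - done)  # ignore rewards after the episode ends
        done = jnp.maximum(done, step_done)    # once done, stay done
        return (next_state, done, total), None
    (_, _, total), _ = jax.lax.scan(
        body, (init_state, jnp.float32(0.0), jnp.float32(0.0)), None, length=max_steps)
    return total

Wrapping such a roll-out in jax.vmap evaluates a whole population of policies in parallel, which is where the hardware acceleration pays off.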

Novel Tasks

In this last category, we go beyond simple illustrations and show examples of novel tasks that are more practical and attractive to researchers in the genetic and evolutionary computation area, with the goal of helping them try out ideas in EvoJAX.

[Animations: Multi-agent WaterWorld; ES-CLIP: "A drawing of a cat"]
  • WaterWorld - In this task, an agent tries to get as much food as possible while avoiding poison. EvoJAX is able to train the agent in tens of minutes on a single GPU. Moreover, we demonstrate that multi-agent training in EvoJAX is possible, which is beneficial for learning policies that can deal with environmental complexity and uncertainty.
  • Abstract Paintings (notebook 1 and notebook 2) - We reproduce the results from this computational creativity work and show how the original work, whose implementation requires multiple CPUs and GPUs, can be accelerated efficiently on a single GPU using EvoJAX, which was not possible before. Moreover, with multiple GPUs/TPUs, EvoJAX can further speed up the mentioned work almost linearly. We also show that the modular design of EvoJAX allows its components to be used independently -- in this case it is possible to use only the ES algorithms from EvoJAX while leveraging one's own training loop and environment implementation (a sketch follows below).
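
For readers who want that last pattern, a minimal ask/tell sketch of using a solver standalone (assuming the NEAlgorithm interface described above; fitness_fn, num_params and the PGPE argument names are placeholders from memory):

from evojax.algo import PGPE

solver = PGPE(pop_size=256, param_size=num_params, seed=0)  # num_params: size of your flat parameter vector
for _ in range(1000):
    params = solver.ask()         # (pop_size, param_size) candidate solutions
    fitness = fitness_fn(params)  # your own evaluation loop / environment
    solver.tell(fitness)          # update the search distribution
best_params = solver.best_params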

Disclaimer

This is not an official Google product.

Comments
  • Some proposals about the `Trainer` logic

    Currently I see two ways of using the Trainer.test_task:

    1. The test_task of the trainer is used for validation. The actual test set is held out and not seen during training or validation. In this case, how do I run the actual test? I can't pass just the test_task to the trainer, because train_task is non-optional. It looks like there should be a way to do this with evojax.
    2. The test_task of the trainer is used for the actual test, and no validation is used at all. In this case, why does trainer.run return the best model score and not the last model score?

    I propose the following (high level) logic:

    best_val_reward = trainer.fit(train_task: VectorizedTask, val_task: Optional[VectorizedTask] = None)  # maybe the user doesn't want validation (e.g. train on latest data without early stopping)
    test_reward = trainer.test(test_task: VectorizedTask, checkpoint="best|last|path")  # specify which checkpoint to use for testing
    

    Early stopping would probably be necessary for the trainer.fit method. Currently there is no way to determine when to stop, or even which model iteration has the best result.
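
    A generic patience-based early-stopping sketch of what fit could do internally (illustrative only, not existing evojax code; all names here are hypothetical):

    best_reward, patience, wait = float('-inf'), 20, 0
    for i in range(max_iter):
        run_training_iteration()                 # hypothetical: one ask/eval/tell step
        if i % test_interval == 0:
            reward = evaluate(val_task)          # hypothetical validation rollout
            if reward > best_reward:
                best_reward, wait = reward, 0
                save_checkpoint('best')          # remember the best iteration
            else:
                wait += 1
                if wait >= patience:
                    break                        # no improvement for `patience` tests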

    I'm willing to implement this logic in a PR.

    opened by danielgafni 7
  • high dimensional parametric search

    I'm trying to use evojax to evolve my model parameters. I found that the algorithm only accepts the parameter num_dims for the dimensionality. Must it be of int type here? If I want to evolve multi-dimensional parameters, such as [1000x1000] data, how can I do it? Thanks!
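
    A common workaround (sketched here as an assumption, not an answer from the maintainers) is to set num_dims = 1000 * 1000 and reshape the flat candidates inside your evaluation function:

    flat = solver.ask()                   # candidates, shape (pop_size, 1000 * 1000)
    mats = flat.reshape(-1, 1000, 1000)   # recover the 1000x1000 structure per candidate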

    opened by Agnes233 5
  • add CR-FM-NES algorithm

    Adds a wrapper for CR-FM-NES; see "Fast Moving Natural Evolution Strategy for High-Dimensional Problems (CR-FM-NES)" (pdf).

    It wraps the fcmaes Eigen/C++ version of CR-FM-NES which is derived from https://github.com/nomuramasahir0/crfmnes.

    Since there are NumPy- and Eigen-based implementations (and soon a JAX-based one) of CR-FM-NES available, it will be possible to compare the performance of these three "backends" for the same algorithm. This commit wraps only the C++/Eigen-based implementation, crfmnes.cpp.

    Tested on NVIDIA 3090 + AMD 5950X, Linux Mint 20 (Ubuntu-based). Performance (wall time) is similar to PGPE, outperforming CMA_ES_JAX. Benchmark results for waterworld are above those of all other algorithms. Run "pip install fcmaes --upgrade" before testing.

    opened by dietmarwo 5
  • Evaluating brax environments other than brax-ant. Terminates with error.

    Information

    The issue occurs when running brax environments other than brax-ant. The included humanoid, half cheetah and fetch environments are affected.

    I couldn't find any references to this issue in the repo; I could have missed something.

    Expected Behavior

    /home/<USER>/anaconda3/envs/evojax/bin/python /home/<USER>/evojax/scripts/benchmarks/train.py -config configs/PGPE/brax_halfcheetah.yaml
    brax: 2022-06-16 20:41:01,954 [INFO] EvoJAX brax
    brax: 2022-06-16 20:41:01,954 [INFO] ==============================
    absl: 2022-06-16 20:41:02,137 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
    absl: 2022-06-16 20:41:02,221 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
    MLPPolicy: 2022-06-16 20:41:03,747 [INFO] MLPPolicy.num_params = 3974
    brax: 2022-06-16 20:41:03,787 [INFO] use_for_loop=False
    brax: 2022-06-16 20:41:03,825 [INFO] Start to train for 1 iterations.
    brax: 2022-06-16 20:41:56,024 [INFO] [TEST] Iter=1, #tests=1, max=-9.7476, avg=-9.7476, min=-9.7476, std=0.0000
    brax: 2022-06-16 20:41:56,087 [INFO] Training done, best_score=-9.7476
    brax: 2022-06-16 20:41:56,093 [INFO] Loaded model parameters from ./log/PGPE/brax/default.
    brax: 2022-06-16 20:41:56,093 [INFO] Start to test the parameters.
    brax: 2022-06-16 20:42:03,478 [INFO] [TEST] #tests=1, max=-9.9009, avg=-9.9009, min=-9.9009, std=0.0000
    

    Current Behavior

    brax: 2022-06-16 20:26:04,657 [INFO] EvoJAX brax
    brax: 2022-06-16 20:26:04,657 [INFO] ==============================
    absl: 2022-06-16 20:26:04,833 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
    absl: 2022-06-16 20:26:04,920 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
    MLPPolicy: 2022-06-16 20:26:06,465 [INFO] MLPPolicy.num_params = 3974
    brax: 2022-06-16 20:26:06,504 [INFO] use_for_loop=False
    brax: 2022-06-16 20:26:06,541 [INFO] Start to train for 10 iterations.
    Traceback (most recent call last):
      File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 88, in <module>
        main(config)
      File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 64, in main
        trainer.run(demo_mode=False)
      File "/home/<USER>/evojax/evojax/trainer.py", line 152, in run
        scores, bds = self.sim_mgr.eval_params(
      File "/home/<USER>/evojax/evojax/sim_mgr.py", line 258, in eval_params
        return self._scan_loop_eval(params, test)
      File "/home/<USER>/evojax/evojax/sim_mgr.py", line 355, in _scan_loop_eval
        scores, all_obs, masks, final_states = rollout_func(
      File "/home/<USER>/evojax/evojax/sim_mgr.py", line 202, in rollout
        (obs_set, obs_mask)) = jax.lax.scan(
      File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1630, in scan
        _check_tree_and_avals("scan carry output and input",
      File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2316, in _check_tree_and_avals
        raise TypeError(f"{what} must have identical types, got\n{diff}.")
    jax._src.traceback_util.UnfilteredStackTrace: TypeError: scan carry output and input must have identical types, got
    (State(state=State(qp=QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), obs='ShapedArray(float32[16384,18])', reward='ShapedArray(float32[16384])', done='ShapedArray(float32[16384])', metrics={'reward_ctrl_cost': 'ShapedArray(float32[16384])', 'reward_forward': 'ShapedArray(float32[16384])'}, info={'first_obs': 'ShapedArray(float32[16384,18])', 'first_qp': QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), 'steps': 'ShapedArray(float32[16384])', 'truncation': 'ShapedArray(float32[16384])'}), obs='ShapedArray(float32[16384,18])', feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])'), PolicyState(keys='ShapedArray(uint32[16384,2])'), 'ShapedArray(float32[16384,3974])', 'ShapedArray(float32[37])', 'ShapedArray(float32[16384])', 'ShapedArray(float32[16384])').
    

    Exact Error:

    feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])')
    

    Failure Information

    Context

    Based on the commit history, this appears to be due to the changes introduced in #33. Manually altering the feet_contact variable in the reset_fn method of evojax/evojax/task/brax_task.py allows the other environments to run.
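
    A hypothetical sketch of that alteration (the actual reset_fn differs; the point is that the feet_contact width should come from the environment instead of being hard-coded to ant's 4):

    feet_contact = jnp.zeros((pop_size, num_feet), jnp.int32)  # num_feet varies per brax environment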

    Hardware setup details are irrelevant, since the error occurs on the hosted Colab notebook as well.

    brax                         0.0.13
    evojax                       0.2.11               
    flax                         0.4.0
    jax                          0.3.1
    jaxlib                       0.3.0+cuda11.cudnn82
    

    Steps to Reproduce

    Please provide detailed steps for reproducing the issue.

    1. Run evojax/scripts/benchmarks/train.py with a modified evojax/scripts/benchmarks/configs/<ES> file using a non-ant brax environment.
    2. Modify the feet_contact array size and test.
    opened by Surya-77 5
  • AssertionError for OpenES

    When I try to instantiate OpenES from open_es.py, I get an AssertionError. I traced the problem to line 110 in open_es.py, where both the centered_rank and z_score arguments are set to True, which conflicts with an assertion at line 26 of the FitnessShaper class in evosax/utils/reshape_fitness.py. How can I get around this issue?
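
    Assuming the assertion forbids enabling both transforms at once, a likely workaround is to construct the shaper with only one of them:

    from evosax import FitnessShaper
    shaper = FitnessShaper(centered_rank=True, z_score=False)  # enable only one ranking transform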

    opened by pigozzif 4
  • Native implementation in JAX of Augmented Random Search

    Test results

    Note: for MNIST I halved the batch size and doubled the iterations due to memory issues.

    | Benchmark | Target (max_iter) | Params | Results (avg.) |
    | --- | --- | --- | --- |
    | CartPole (easy) | 900 (max_iter=1000) | Link | 910 |
    | CartPole (hard) | 600 (max_iter=2000) | Link | 558.02 |
    | MNIST | 0.90 (max_iter=4000) | Link | 0.92 |
    | Brax Ant | 3000 (max_iter=700) | Link | 4129.83 |
    | Waterworld | 6 (max_iter=2000) | Link | 7.29 |
    | Waterworld (MA) | 2 (max_iter=2000) | Link | 1.68 |

    opened by EdoardoPona 4
  • AbstractPainting02.ipynb doesn't work on Colab

    Hello, this is really great code.

    I was able to run "Abstract Painting 01" on Google Colab without problems. However, when I ran "AbstractPainting02", an error occurred.

    Exception                                 Traceback (most recent call last)
    <ipython-input-20-b16203d22159> in <module>()
          2 devices = jax.local_devices()
          3 
    ----> 4 image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load('ViT-B/32', "cpu")
          5 
          6 target_text_ids = jnp.array(clip_jax.tokenize([prompt])) # already with batch dim
    
    3 frames
    /content/CLIP_JAX/clip_jax/clip.py in process_node(value, name)
        117             new_tensor = jnp.array(pytorch_tensor)
        118         else:
    --> 119             raise Exception("not implemented")
        120 
        121         assert new_tensor.shape == value.shape
    
    Exception: not implemented
    

    Which version of clip_jax did you use when you made this?

    Best

    opened by shi3z 4
  • Evosax - Sep-CMA-ES

    • Reference: Ros & Hansen (2008)
    • evosax Source Code: https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/sep_cma_es.py
    • This PR adds a CMA-ES version which imposes a diagonal structure on the estimated covariance matrix, making it much more memory-efficient than pure CMA-ES, which has to store a (d x d) matrix.
    • Benchmarks and hyperparameters:

    | Benchmark | Target (max_iter) | Parameters | Results (avg.) |
    | --- | --- | --- | --- |
    | CartPole (easy) | 900 (max_iter=1000) | Link | 924.3028 |
    | CartPole (hard) | 600 (max_iter=1000) | Link | 626.9728 |
    | MNIST | 90.0 (max_iter=2000) | Link | 0.9545 |
    | Brax Ant | 3000 (max_iter=300) | Link | 3980.9194 |
    | Waterworld | 6 (max_iter=500) | Link | 9.9000 |
    | Waterworld (MA) | 2 (max_iter=2000) | Link | 1.1875 |

    Note: linting doesn't pass due to an import error for Open_ES - see PR #19, which has to be merged first.

    opened by RobertTLange 3
  • Adding a Linear Policy

    This is a simple linear policy (a one-layer neural network). This policy is especially useful for control-related tasks with, for example, augmented random search. In fact, in the original ARS paper, one of the algorithm's key advantages is its ability to find high-performing linear policies.

    I created a new policy rather than editing MLP, for simplicity and since they would most likely be used in different contexts (e.g., different tasks and algorithms).

    opened by EdoardoPona 2
  • Add a Python/JAX port of CR-FM-NES

    This PR adds a Python/JAX port of Fast Moving Natural Evolution Strategy for High-Dimensional Problems (CR-FM-NES); see https://arxiv.org/abs/2201.11422. It is derived from https://github.com/nomuramasahir0/crfmnes.

    This variant is slightly faster than FCRFMC (the C++ port) on fast GPUs/TPUs, but slower on CPUs and for smaller dimensions. It uses 32-bit precision (FCRFMC uses 64-bit), which mostly doesn't harm convergence (Waterworld MA being the exception at very high iteration numbers).

    Wall time and convergence are mostly comparable with PGPE (as with FCRFMC) on the benchmarks: slower in the beginning, but improving at higher iterations.

    Since there are no for-loops, I found no beneficial applications of jax.jit; I just converted most np.arrays into jnp.arrays deployed on the GPUs/TPUs.

    def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:

    deliberately takes evals as np.ndarray rather than jnp.ndarray, because the latter slowed things down on my NVIDIA 3090.

    Since this is Python code, there are no missing shared libraries on Ubuntu 18 this time.

    Added test results for CRFMNES (this Python implementation) to EvoJax.adoc.

    opened by dietmarwo 2
  • Reproducing benchmark scores

    Hello everyone.

    I am currently trying to reproduce scores from the benchmarks, specifically for ARS, as I am implementing my own JAX-native version and want to compare it with the wrapper already implemented.

    For example, I cannot achieve the score posted in the benchmark table (902.107) for ARS on cartpole_easy.

    Running python train.py -config configs/ARS/cartpole_easy.yaml yields the following training logs:

    cartpole_easy: 2022-09-25 22:45:55,777 [INFO] EvoJAX cartpole_easy
    cartpole_easy: 2022-09-25 22:45:55,777 [INFO] ==============================
    absl: 2022-09-25 22:45:55,791 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
    absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
    absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
    MLPPolicy: 2022-09-25 22:45:59,165 [INFO] MLPPolicy.num_params = 4609
    cartpole_easy: 2022-09-25 22:45:59,429 [INFO] use_for_loop=False
    cartpole_easy: 2022-09-25 22:45:59,496 [INFO] Start to train for 1000 iterations.
    cartpole_easy: 2022-09-25 22:46:10,527 [INFO] Iter=50, size=100, max=399.5886, avg=207.9111, min=0.5843, std=99.0207
    cartpole_easy: 2022-09-25 22:46:19,916 [INFO] Iter=100, size=100, max=543.8907, avg=364.9780, min=28.8478, std=141.8982
    cartpole_easy: 2022-09-25 22:46:21,143 [INFO] [TEST] Iter=100, #tests=100, max=553.4018 avg=510.5583, min=462.4243, std=15.6930
    cartpole_easy: 2022-09-25 22:46:30,627 [INFO] Iter=150, size=100, max=558.2020, avg=314.9279, min=89.8001, std=153.6488
    cartpole_easy: 2022-09-25 22:46:40,068 [INFO] Iter=200, size=100, max=562.4118, avg=354.9529, min=47.0048, std=154.1567
    cartpole_easy: 2022-09-25 22:46:40,114 [INFO] [TEST] Iter=200, #tests=100, max=570.1135 avg=547.5375, min=508.5795, std=10.0840
    cartpole_easy: 2022-09-25 22:46:49,579 [INFO] Iter=250, size=100, max=562.1505, avg=325.3990, min=73.3733, std=161.9460
    cartpole_easy: 2022-09-25 22:46:59,073 [INFO] Iter=300, size=100, max=569.5461, avg=370.2641, min=83.7473, std=166.8020
    cartpole_easy: 2022-09-25 22:46:59,129 [INFO] [TEST] Iter=300, #tests=100, max=573.5941 avg=545.0388, min=505.8637, std=11.3853
    cartpole_easy: 2022-09-25 22:47:08,623 [INFO] Iter=350, size=100, max=579.3894, avg=425.6462, min=82.4907, std=126.6614
    cartpole_easy: 2022-09-25 22:47:18,109 [INFO] Iter=400, size=100, max=627.6509, avg=530.2781, min=156.4797, std=76.0956
    cartpole_easy: 2022-09-25 22:47:18,160 [INFO] [TEST] Iter=400, #tests=100, max=639.7323 avg=600.9105, min=573.7767, std=10.7564
    cartpole_easy: 2022-09-25 22:47:27,653 [INFO] Iter=450, size=100, max=668.2064, avg=546.0261, min=418.5385, std=60.5854
    cartpole_easy: 2022-09-25 22:47:37,149 [INFO] Iter=500, size=100, max=684.4142, avg=574.4891, min=446.3126, std=62.5338
    cartpole_easy: 2022-09-25 22:47:37,202 [INFO] [TEST] Iter=500, #tests=100, max=693.1522 avg=682.7945, min=638.0387, std=12.1575
    cartpole_easy: 2022-09-25 22:47:46,708 [INFO] Iter=550, size=100, max=708.9561, avg=591.0547, min=295.5651, std=73.6026
    cartpole_easy: 2022-09-25 22:47:56,212 [INFO] Iter=600, size=100, max=706.8138, avg=599.4783, min=348.7581, std=55.6310
    cartpole_easy: 2022-09-25 22:47:56,263 [INFO] [TEST] Iter=600, #tests=100, max=691.0123 avg=680.4677, min=630.2983, std=6.1448
    cartpole_easy: 2022-09-25 22:48:05,770 [INFO] Iter=650, size=100, max=707.0887, avg=581.3851, min=418.2251, std=75.9066
    cartpole_easy: 2022-09-25 22:48:15,275 [INFO] Iter=700, size=100, max=712.7586, avg=586.4597, min=362.7628, std=71.5669
    cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
    cartpole_easy: 2022-09-25 22:48:24,849 [INFO] Iter=750, size=100, max=716.1056, avg=602.7747, min=458.0401, std=63.1697
    cartpole_easy: 2022-09-25 22:48:34,365 [INFO] Iter=800, size=100, max=709.3475, avg=587.9896, min=393.0367, std=69.2385
    cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
    cartpole_easy: 2022-09-25 22:48:43,945 [INFO] Iter=850, size=100, max=706.8488, avg=598.3582, min=321.8640, std=75.2542
    cartpole_easy: 2022-09-25 22:48:53,482 [INFO] Iter=900, size=100, max=720.0320, avg=596.1929, min=370.6555, std=77.2801
    cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
    cartpole_easy: 2022-09-25 22:49:03,068 [INFO] Iter=950, size=100, max=716.2341, avg=598.3802, min=422.7760, std=71.7756
    cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
    cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952
    cartpole_easy: 2022-09-25 22:49:12,458 [INFO] Loaded model parameters from ./log/ARS/cartpole_easy/default.
    cartpole_easy: 2022-09-25 22:49:12,459 [INFO] Start to test the parameters.
    cartpole_easy: 2022-09-25 22:49:12,509 [INFO] [TEST] #tests=100, max=728.9848, avg=720.6152, min=698.9832, std=5.0566
    

    I am not entirely sure whether the result in the benchmark table is intended to be the best_score=720.5952 from the line "cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952" or the max score from the final test. Regardless, neither of these matches the one posted in the benchmark table.

    Am I doing something wrong in reproducing these scores? This prevents me from comparing my own implementation of the algorithm.

    Thank you

    opened by EdoardoPona 2
  • Add Diversifier QD Meta Algorithm - JAX backend

    This PR adds a new JAX-based QD meta algorithm called Diversifier. It is a generalization of CMA-ME.

    It uses a MAP-Elites archive not for generating solution candidates, but only to modify the fitness values passed (via tell) to the wrapped algorithm. This modification changes the fitness ranking of the population to favor exploration over exploitation. Tested with CR-FM-NES and CMA-ES, but other wrapped algorithms may work as well. Based on the fcmaes diversifier.py (see MapElites.adoc).

    The generalization over CMA-ME is necessary in the EvoJAX context because CMA-ES struggles with a very high number of decision variables. CR-FM-NES-ME is therefore superior here, as possibly are other not-yet-tested alternatives.

    https://doi.org/10.1145/2739480.2754664 proposes the QD score (the sum of the fitness values of all elites in the map) as a metric for comparison.

    For Brax-Ant, CR-FM-NES-ME (the Diversifier applied to CR-FM-NES) reaches a higher QD score than MAP-Elites at high iteration numbers (see details below). So MAP-Elites should only be preferred for a low evaluation budget, or if you want to maximize the number of occupied niches.

    On an NVIDIA 3090 + AMD 5950X (Linux Mint) with optimized configurations we measured:

    • MAP-Elites has the same optimizer overhead (evaluations/sec for the same popsize).
    • MAP-Elites has a higher number of occupied niches.

    but

    • CR-FM-NES-ME has a much higher QD score and found a better global optimum for a high evaluation budget.

    Detailed measurements for the Brax-Ant example (NVIDIA 3090 + AMD 5950, Linux Mint):

    After 20 minutes MAP-Elites is in the lead, but slows down from there. CR_FM_NES-ME continues to improve until 500 minutes / 8 million evaluations, and can even produce a good global optimum (4107) while still occupying 6138 niches with a mean score of 1208. After 500 minutes MAP-Elites continues to improve where CR_FM_NES-ME does not, but by that time CR_FM_NES-ME has a >70% lead in score.

    CR_FM_NES-ME with init-std = 0.159, popsize = 512, fitness_weight = 0.0:

    | Time (min) | QD score | Occupied | Max score | Mean score | Evaluations |
    | --- | --- | --- | --- | --- | --- |
    | 20 | 1692282 | 4936 | 558 | 342 | 263680 |
    | 50 | 2724260 | 5628 | 918 | 484 | 704512 |
    | 100 | 4289807 | 6087 | 1442 | 704 | 1496576 |
    | 200 | 5928753 | 6138 | 2363 | 965 | 3072000 |
    | 300 | 6524518 | 6138 | 2862 | 1063 | 4710400 |
    | 400 | 7353257 | 6138 | 3889 | 1198 | 6348800 |
    | 500 | 7418018 | 6138 | 4107 | 1208 | 7884800 |
    | 600 | 7444092 | 6138 | 4211 | 1212 | 9523200 |

    MAP-Elites with iso-sigma = 0.05, line-sigma = 0.2, popsize = 1024 (line-sigma = 0.3 is worse):

    | Time (min) | QD score | Occupied | Max score | Mean score | Evaluations |
    | --- | --- | --- | --- | --- | --- |
    | 20 | 2509773 | 5621 | 643 | 446 | 346112 |
    | 50 | 3022521 | 6375 | 724 | 474 | 915456 |
    | 100 | 3383041 | 6786 | 769 | 498 | 1941504 |
    | 200 | 3713977 | 7107 | 825 | 522 | 3936256 |
    | 300 | 3915492 | 7265 | 927 | 538 | 5922816 |
    | 400 | 4065677 | 7400 | 927 | 549 | 7941120 |
    | 500 | 4179020 | 7498 | 927 | 557 | 9958400 |
    | 600 | 4272665 | 7566 | 927 | 564 | 12083200 |
    | 700 | 4351397 | 7632 | 941 | 570 | 14094336 |
    | 800 | 4415351 | 7675 | 1003 | 575 | 16040960 |

    These results indicate that it should be possible to apply MAP-Elites to the resulting CR_FM_NES-ME archive to further improve occupancy and score. As the algorithm wrapped by diversifier.py, CRFMNES can be replaced by FCRFMC (the same algorithm implemented in C++). We got the same results, but this may reduce the GPU load for smaller GPUs/TPUs and is definitively advantageous for CPU-only execution. On the NVIDIA 3090, CRFMNES is slightly faster.

    Note that 'fitness_weight' is a concept used neither in CMA-ME nor in the fcmaes diversifier; all of these implicitly use fitness_weight=0. For fcmaes the reason is that there are other means to improve the elites of a given map, so the focus there is on exploration. We use fitness_weight=0 as the default because for Brax Ant the final QD score is higher, although the final global optimum found is lower.

    fcmaes even supports sequences of wrapped algorithms, something probably not relevant for EvoJAX.

    Increasing popsize to 1024 closes the evaluations/sec gap to MAP-Elites; the rate is 34% higher than with popsize = 512. But popsize = 1024 seems to produce lower occupancy, which is quite surprising:

    CR_FM_NES-ME with init-std = 0.159, popsize = 1024, fitness_weight = 0.0:

    | Time (min) | QD score | Occupied | Max score | Mean score | Evaluations |
    | --- | --- | --- | --- | --- | --- |
    | 20 | 1864258 | 4856 | 538 | 383 | 350208 |
    | 50 | 2702674 | 5402 | 848 | 500 | 905216 |
    | 100 | 3853807 | 5873 | 1288 | 656 | 1879040 |
    | 200 | 5005292 | 5947 | 1781 | 841 | 3891200 |
    | 300 | 6425120 | 5963 | 2936 | 1077 | 5963776 |
    | 400 | 7103424 | 5976 | 3783 | 1188 | 8192000 |
    | 500 | 7282457 | 5980 | 4111 | 1217 | 10240000 |
    | 600 | 7371868 | 5982 | 4227 | 1232 | 12288000 |
    | 700 | 7405531 | 5982 | 4276 | 1237 | 14336000 |
    | 800 | 7464371 | 5983 | 4307 | 1247 | 16384000 |
    | 900 | 7509290 | 5989 | 4325 | 1253 | 18432000 |
    | 1000 | 7514184 | 5989 | 4342 | 1254 | 20480000 |

    But why can't we have our cake and eat it too?

    This is not part of the PR but discusses what could be done in the future:

    Both Diversifier and MAP-Elites share the same archive management; they differ only in population generation. In the future both could be unified into a single QD solver, still called MAP-Elites. This new implementation could randomly choose how ask works: we define a probability with which the wrapped solver is used instead of the standard mechanism. If this probability is 0, we have the old MAP-Elites; if it is 1.0, we have the Diversifier. The interesting question is what happens for values in between. Let's try 0.5, which can easily be implemented as:

    def ask(self) -> jnp.ndarray:
        self.key, key = jax.random.split(self.key)
        if jax.random.uniform(key) > 0.5:  # the 0.5 threshold is a parameter to play with
            # population from the wrapped solver (Diversifier behavior)
            self.population = self.solver.ask()
            self.solver_asked = True
        else:
            # population from the MAP-Elites generator
            self.key, mutate_key, parents = self._sample_parents(
                key=self.key,
                occupancy=self.occupancy_lattice,
                params=self.params_lattice)
            self.population = self._gen_pop(parents, mutate_key)
            self.solver_asked = False
        return self.population

    def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
        if self.solver_asked:
            lattice_fitness = self.fitness_lattice[self.bin_idx]
            to_tell = self._get_to_tell(fitness, lattice_fitness, self.fitness_weight)
            self.solver.tell(to_tell)
        # update the lattice with the new fitness values (as in MAP-Elites)

    MAP-Elites + CR_FM_NES-ME with iso-sigma = 0.05, line-sigma = 0.2, init-std = 0.159, popsize = 1024, fitness_weight = 0.0:

    | Time (min) | QD score | Occupied | Max score | Mean score | Evaluations |
    | --- | --- | --- | --- | --- | --- |
    | 20 | 2201387 | 5266 | 672 | 418 | 344064 |
    | 50 | 2738691 | 6105 | 672 | 448 | 892928 |
    | 100 | 3348423 | 6656 | 857 | 503 | 1859584 |
    | 200 | 4233393 | 7135 | 1103 | 593 | 3851264 |
    | 300 | 5139277 | 7334 | 1586 | 700 | 5893120 |
    | 400 | 5776929 | 7457 | 1884 | 774 | 7943168 |
    | 500 | 6098261 | 7537 | 2104 | 809 | 9947136 |
    | 600 | 7240351 | 7603 | 2890 | 952 | 11999232 |
    | 700 | 7800357 | 7660 | 3421 | 1018 | 14023680 |
    | 800 | 8004109 | 7699 | 3735 | 1039 | 16000000 |

    This is an 81% QD score increase compared to MAP-Elites alone, while also improving occupancy.

    Continuing the same run:

    | Time (min) | QD score | Occupied | Max score | Mean score | Evaluations |
    | --- | --- | --- | --- | --- | --- |
    | 900 | 8115904 | 7744 | 3917 | 1048 | 18140160 |
    | 1000 | 8195842 | 7772 | 4019 | 1054 | 20133888 |
    | 1100 | 8259249 | 7799 | 4090 | 1059 | 22155264 |
    | 1200 | 8319027 | 7826 | 4130 | 1062 | 24177664 |
    | 1300 | 8362034 | 7847 | 4156 | 1065 | 26213376 |

    A QD score of 8362034 is probably a challenge for any algorithm, independent of the evaluation budget.

    opened by dietmarwo 1
  • Bug of center_lr_decay_steps when using Adam with PGPE

    Bug

    When using Adam with PGPE, this code

    self._opt_state = self._opt_update(
                self._t // self._lr_decay_steps, -grad_center, self._opt_state
            )
    

    means Adam's step counter t only advances once every self._lr_decay_steps iterations. As a result, mhat and vhat will not work as moving averages, because (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) will always stay very small. (Below is the Adam update code.)

    def update(i, g, state):
        x, m, v = state
        m = (1 - b1) * g + b1 * m  # First  moment estimate.
        v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
        mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
        vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
        x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
        return x, m, v
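
    To see the effect with illustrative numbers: with b1 = 0.9 and self._lr_decay_steps = 1000, the counter i = t // 1000 stays 0 for the first 1000 updates, so the bias correction is stuck:

    b1 = 0.9
    i = 0                                # stuck at 0 while t < lr_decay_steps
    bias_correction = 1 - b1 ** (i + 1)  # = 0.1 on every one of those steps
    # so mhat = m / 0.1 amplifies m by 10x instead of approaching m as i grows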
    

    Suggestion

    I think it would be better to change this code to

    step_size=lambda x: self._center_lr * jnp.power(decay_coef, x // self._lr_decay_steps),
    

    and to remove the division by self._lr_decay_steps:

    self._opt_state = self._opt_update(
                self._t, -grad_center, self._opt_state
            )
    
    opened by garam-kim1 1
  • Save top n models per checkpoint

    As I understand it, currently only the best model from the population is saved at the end of each iteration. This may lead to inconsistent train/test results (due to overfitting) in some setups. Blending the top n models could potentially reduce this effect, as sketched below.
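
    A sketch of what blending could look like (hypothetical; assumes params holds one parameter row per population member, aligned with fitness):

    import jax.numpy as jnp

    top_n = 5
    idx = jnp.argsort(fitness)[-top_n:]  # indices of the n best candidates
    blended = params[idx].mean(axis=0)   # average their parameter vectors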

    Would you be interested in this feature for evojax? I can work on a PR. It seems that not all solvers can support this feature.

    opened by danielgafni 0
  • fix support for multi-dim observations

    Hey! I found a bug in the observation normalization code. It occurs when the observations are not a flat array but a multi-dimensional one: the obs_normalizer params are stored as a flat array, so the code fails in this case. Here is the fix for this bug.

    opened by danielgafni 5
  • JAX implementation of CMAES

    Hi, I'm really amazed by this library.

    Currently, CMAES is just a wrapper. I implemented a JAX-native CMA-ES based on https://github.com/CyberAgentAILab/cmaes/.

    opened by moskomule 7