evojax's Introduction

EvoJAX: Hardware-Accelerated Neuroevolution

EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPU/GPUs. EvoJAX achieves very high performance by implementing the evolution algorithm, neural network and task all in NumPy, which is compiled just-in-time to run on accelerators.

This repo also includes several extensible examples of EvoJAX for a wide range of tasks, including supervised learning, reinforcement learning and generative art, demonstrating how EvoJAX can run your evolution experiments within minutes on a single accelerator, compared to hours or days when using CPUs.

EvoJAX paper: https://arxiv.org/abs/2202.05008 (presentation video)

Please use this BibTeX if you wish to cite this project in your publications:

@article{evojax2022,
  title={EvoJAX: Hardware-Accelerated Neuroevolution},
  author={Tang, Yujin and Tian, Yingtao and Ha, David},
  journal={arXiv preprint arXiv:2202.05008},
  year={2022}
}

List of publications using EvoJAX (please open a PR to add missing entries):

Installation

EvoJAX is implemented in JAX, which needs to be installed first.

Install JAX: Please first follow JAX's installation instructions, with optional GPU/TPU backend support. If JAX is not set up, the EvoJAX installation will still try to pull in a CPU-only version of JAX. Note that Colab runtimes come with JAX pre-installed.

Install EvoJAX:

# Install from PyPI.
pip install evojax

# Or, install from our GitHub repo.
pip install git+https://github.com/google/evojax.git@main

If you also want to install the extra dependencies required for certain optional functionalities, use

pip install evojax[extra]
# Or
pip install git+https://github.com/google/evojax.git@main#egg=evojax[extra]

Code Overview

EvoJAX is a framework with three major components, which we expect the users to extend.

  1. Neuroevolution Algorithms: All neuroevolution algorithms should implement the evojax.algo.base.NEAlgorithm interface and reside in evojax/algo/. See here for the available algorithms in EvoJAX.
  2. Policy Networks: All neural networks should implement the evojax.policy.base.PolicyNetwork interface and be saved in evojax/policy/. In this repo, we give example implementations of the MLP, ConvNet, Seq2Seq and PermutationInvariant models.
  3. Tasks: All tasks should implement evojax.task.base.VectorizedTask and be in evojax/task/.

These components can be used either independently, or orchestrated by evojax.trainer and evojax.sim_mgr that manage the training pipeline. While they should be sufficient for the currently provided policies and tasks, we plan to extend their functionality in the future as the need arises.
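
As a rough illustration of how the three components fit together, the sketch below runs an ask/tell loop in plain Python. Everything in it (`ToySolver`, `toy_policy`, `toy_task`) is a hypothetical stand-in written for this example and is not EvoJAX's API; the real interfaces are the ones named above.

```python
import random

def toy_policy(params, obs):
    """Hypothetical linear 'policy': a dot product of parameters and observation."""
    return sum(p * o for p, o in zip(params, obs))

def toy_task(params):
    """Hypothetical 'task': reward is higher when the policy output is close to 1.0."""
    obs = [1.0, 2.0, 3.0]
    return -(toy_policy(params, obs) - 1.0) ** 2

class ToySolver:
    """Hypothetical truncation-selection ES with an ask/tell interface."""

    def __init__(self, dim, pop_size=32, elite=8, sigma=0.1, seed=0):
        self.rng = random.Random(seed)
        self.center = [0.0] * dim
        self.pop_size, self.elite, self.sigma = pop_size, elite, sigma
        self.population = []

    def ask(self):
        # Sample candidate parameter vectors around the current center.
        self.population = [
            [c + self.sigma * self.rng.gauss(0.0, 1.0) for c in self.center]
            for _ in range(self.pop_size)
        ]
        return self.population

    def tell(self, fitness):
        # Move the center to the mean of the best candidates.
        ranked = sorted(zip(fitness, self.population), key=lambda fp: fp[0], reverse=True)
        best = [params for _, params in ranked[: self.elite]]
        self.center = [sum(col) / len(best) for col in zip(*best)]

solver = ToySolver(dim=3)
for _ in range(100):
    candidates = solver.ask()                        # algorithm proposes parameters
    solver.tell([toy_task(p) for p in candidates])   # task scores each policy
final_reward = toy_task(solver.center)
```

The loop only exchanges parameter vectors and scores, which is why the algorithm, the policy and the task can be swapped independently.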

Examples

As a quickstart, we provide non-trivial examples (scripts in examples/ and notebooks in examples/notebooks) to illustrate the usage of EvoJAX. We provide example commands to start the training process at the top of each script. These scripts and notebooks were run on TPUs and/or NVIDIA V100 GPU(s):

Supervised Learning Tasks

While one would obviously use gradient-descent for such tasks in practice, the point is to show that neuroevolution can also solve them to some degree of accuracy within a short amount of time, which will be useful when these models are adapted within a more complicated task where gradient-based approaches may not work.

  • MNIST Classification - We show that EvoJAX trains a ConvNet policy to achieve >98% test accuracy within 5 min on a single GPU.
  • Seq2Seq Learning - We demonstrate that EvoJAX is capable of learning a large network with hundreds of thousands of parameters to accomplish a seq2seq task.

Classic Control Tasks

The purpose of including control tasks is two-fold: 1) Unlike supervised learning tasks, control tasks in EvoJAX have an undetermined number of steps, so we use these examples to demonstrate the efficiency of our task roll-out loops. 2) We wish to show the speed-up benefit of implementing tasks in JAX and illustrate how to implement one from scratch.

  • Locomotion - Brax is a differentiable physics engine implemented in JAX. We wrap it as a task and train with EvoJAX on GPUs/TPUs. It takes EvoJAX tens of minutes to solve a locomotion task in Brax.
  • Cart-Pole Swing Up - We illustrate how the classic control task can be implemented in JAX and integrated into EvoJAX's pipeline to significantly speed up training.
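
The point about roll-outs with an undetermined number of steps can be made concrete with a small sketch. One common way to handle episodes of unknown length in a batched, fixed-length loop is to keep stepping every environment and mask out rewards once an episode has ended. The plain-Python code below shows the idea; `toy_step` and the episode lengths are made up for illustration, and EvoJAX's actual roll-out code may differ.

```python
def toy_step(state):
    """Hypothetical environment step: returns (next_state, reward, done)."""
    position, horizon = state
    position += 1
    return (position, horizon), 1.0, position >= horizon

def masked_rollout(initial_states, max_steps):
    """Run every episode for max_steps, counting reward only while it is alive."""
    states = list(initial_states)
    alive = [1.0] * len(states)      # 1.0 while running, 0.0 once finished
    returns = [0.0] * len(states)
    for _ in range(max_steps):
        for i, state in enumerate(states):
            next_state, reward, done = toy_step(state)
            returns[i] += alive[i] * reward      # masked reward accumulation
            alive[i] *= 0.0 if done else 1.0     # stays 0.0 after termination
            states[i] = next_state
    return returns

# Three episodes that end after 2, 5 and 50 steps; the loop length is fixed at 10.
episode_returns = masked_rollout([(0, 2), (0, 5), (0, 50)], max_steps=10)
```

Because every episode runs the same number of loop iterations, this pattern maps naturally onto vectorized, compiled loops.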

Novel Tasks

In this last category, we go beyond simple illustrations and show examples of novel tasks that are more practical and attractive to researchers in the genetic and evolutionary computation area, with the goal of helping them try out ideas in EvoJAX.

Figures: Multi-agent WaterWorld; ES-CLIP: “A drawing of a cat”; Slime Volleyball
  • WaterWorld - In this task, an agent tries to get as much food as possible while avoiding poisons. EvoJAX is able to train the agent in tens of minutes on a single GPU. Moreover, we demonstrate that multi-agent training in EvoJAX is possible, which is beneficial for learning policies that can deal with environmental complexity and uncertainties.
  • Abstract Paintings (notebook 1 and notebook 2) - We reproduce the results from this computational creativity work and show how the original work, whose implementation requires multiple CPUs and GPUs, could be accelerated on a single GPU efficiently using EvoJAX, which was not possible before. Moreover, with multiple GPUs/TPUs, EvoJAX can further speed up the mentioned work almost linearly. We also show that the modular design of EvoJAX allows its components to be used independently -- in this case it is possible to use only the ES algorithms from EvoJAX while leveraging one's own training loops and environment implementation.
  • Neural Slime Volleyball - In this task, the agent's goal is to get the ball to land on the ground of its opponent's side, causing its opponent to lose a life. The episode ends when either agent loses all five lives, or after the time limit. An agent receives a reward of +1 when its opponent loses a life and -1 when it loses one itself. EvoJAX is able to train the agent in under 5 minutes on a single GPU, compared to hours on multiple CPUs. This implementation is based on the Slime Volleyball Gym Environment, which is a Python port of the original JavaScript version of the game that you can play in the web browser. In all of these versions, the built-in AI opponent and the less-than-ideal physics are identical.

Call for Contributions

The goal of EvoJAX is to get evolutionary computation to be able to work on a vast array of tasks using accelerators.

One issue in the past was that many evolution algorithms were optimized for only one particular task in a given paper. This is the reason we focused on one single algorithm (PGPE) in the first release of EvoJAX, while creating 6+ different tasks in diverse domains, ensuring that one single algorithm works for all of the tasks without any issues. See Table of contributed algorithms.

Evolutionary Algorithms

We welcome new evolution algorithms to be added to this toolkit. It would be great if you can show that your implementation performs well on cart-pole swing-up (hard mode), Brax, WaterWorld, and MNIST before submitting a pull request.

Ideas for evolutionary algorithm candidates:

  • Your favorite Genetic Algorithm.
  • CMA-ES (bare version, and improved versions such as BIPOP-CMA-ES)
  • Augmented Random Search (paper)
  • AMaLGaM-IDEA (paper)

We suggest the below performance guidelines for new algorithms:

  1. MNIST: 90%+
  2. Cartpole: 900+ (easy), 600+ (hard)
  3. Waterworld: 6+ (single-agent), 2+ (multi-agent)
  4. Brax ant: 3000+

Note that these are not hard requirements, but just rough guidelines.
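
For convenience, the guideline numbers above can be written down as a small lookup table and compared against reported scores. This is only a sketch: the task keys and the `check_scores` helper are hypothetical names chosen for this example, and the thresholds are the rough guidelines listed above, not hard requirements.

```python
# Rough guideline thresholds copied from the list above (keys are hypothetical labels).
GUIDELINES = {
    "mnist": 90.0,            # test accuracy in percent
    "cartpole_easy": 900.0,
    "cartpole_hard": 600.0,
    "waterworld": 6.0,        # single-agent
    "waterworld_ma": 2.0,     # multi-agent
    "brax_ant": 3000.0,
}

def check_scores(scores):
    """Return {task: True/False} for every reported task with a known guideline."""
    return {
        task: score >= GUIDELINES[task]
        for task, score in scores.items()
        if task in GUIDELINES
    }

# Example with made-up scores for two tasks.
report = check_scores({"mnist": 98.2, "cartpole_easy": 720.6})
```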

Please use the benchmark script to evaluate your algorithm before sending us a PR, and let us know if you are unable to test on some tasks due to hardware limitations. See this example pull request thread of a Genetic Algorithm that has been merged into EvoJAX to see how it should be done.

Feel free to reach out to [email protected] if you wish to discuss further.

New Tasks

We also welcome new tasks and examples (see here for all tasks in EvoJAX). Some suggestions:

  • Train a Neural Turing Machine using evolution to come up with a sorting algorithm.
  • Soccer via self-play (Example)
  • Evolving Hebbian Learning-capable plastic networks that can remember the map of a maze from the agent’s recent experience.
  • Adaptive Computation Time for RNNs performing a task that requires an unknown number of steps.
  • Tasks that make use of hard attention.

Sister Projects

There is a growing number of researchers working with evolutionary computation who are using JAX. Here is a list of related efforts:

  • QDax: Accelerated Quality-Diversity. A tool that uses JAX to help accelerate Quality-Diversity (QD) algorithms through hardware accelerators and massive parallelism. (GitHub | paper)

  • evosax: A JAX-based library of evolution strategies focusing on JAX-composable ask-tell functionality and strategy diversity. More than 10 ES algorithms implemented. (GitHub)

Disclaimer

This is not an official Google product.

evojax's People

Contributors

alantian, bastianzim, chakazul, danielgafni, dietmarwo, edoardopona, flipchip167, hardmaru, lerrytang, llionj, maximilienlc, mayu-snba19, roberttlange, surya-77

evojax's Issues

[Question] Are the jits around pmaps intended?

self._train_rollout_fn = jax.jit(jax.pmap(

The docs seem to say that jits around pmaps are unnecessary: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#pmap-and-jit

While running experiments, I also often get this warning which seems to say that it might be problematic:
UserWarning: The jitted function <unnamed function> includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See [https://github.com/google/jax/issues/2926].

If the behaviour is intended please discard.

How to design a custom Seq2seq model with evojax?

Dear developer:
I am developing a large seq2seq model with evojax. However, I found it inflexible to develop my custom model with a custom vocabulary based on the seq2seq example: I need to revise the source code rather than import functions or adjust parameters. I would appreciate it if you could give me some guidance.
By the way, thanks for your awesome work.

Save top n models per checkpoint

As I understand it, currently only the best model from the population is saved at the end of the iteration. This may lead to inconsistent train/test results (due to overfitting) in some setups. Blending the top n models could potentially reduce this effect.

Would you be interested in this feature for evojax? I can work on a PR. It seems that not all solvers can support this feature.
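
A minimal sketch of the idea, in plain Python: keep the n best parameter vectors of a generation and, as one possible way of blending them, average their parameters. The function names are hypothetical, and parameter averaging is an assumption made for this example; averaging the models' outputs would be another option.

```python
import heapq

def top_n_params(population, scores, n):
    """Return the n parameter vectors with the highest scores, best first."""
    ranked = heapq.nlargest(n, zip(scores, range(len(population))))
    return [population[i] for _, i in ranked]

def blend(params_list):
    """Average several parameter vectors element-wise."""
    return [sum(values) / len(values) for values in zip(*params_list)]

# Made-up population of four 2-parameter models and their scores.
population = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
scores = [0.1, 0.9, 0.5, 0.7]
best_two = top_n_params(population, scores, n=2)
blended = blend(best_two)
```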

opencv-python dependency?

Hi everyone, and thanks for working on such a project and open sourcing it!

I notice that the project has a dependency on opencv-python listed in the setup.py file. However, I did not find any import cv2 in the project, and the examples seem to work fine even in an environment with no opencv-python/cv2. Is the dependency on opencv-python just a leftover/accidental dependency, or is there something that I am missing?

Thanks in advance.

Advanced loggers

It would be nice to use Tensorboard or any other advanced logging tool with evojax.
Looks like it should be straightforward to allow the user to implement a log_reward function that could be passed to the trainer.

I'm willing to implement this in a PR.
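
One way such a hook could look is sketched below. `run_training` is a hypothetical stand-in for a training loop, not EvoJAX's `Trainer`; the only point is that a user-supplied `log_reward` callable receives the iteration number and score statistics and can forward them to any logging backend.

```python
import statistics

def run_training(n_iters, evaluate, log_reward=None):
    """Hypothetical training loop that reports score statistics through a callback."""
    for iteration in range(1, n_iters + 1):
        scores = evaluate(iteration)
        if log_reward is not None:
            log_reward(iteration, {
                "max": max(scores),
                "mean": statistics.mean(scores),
                "min": min(scores),
            })

history = []

def log_to_list(iteration, stats):
    # A real callback could write to TensorBoard or another tool instead.
    history.append((iteration, stats["mean"]))

# The evaluate function here just returns made-up scores.
run_training(3, evaluate=lambda it: [float(it), float(it) + 2.0], log_reward=log_to_list)
```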

Multi-Agent RL Environment for CrowdSim, Predator-Prey, and Army

Hello again,

I created crowd and predator-prey environments that may be useful for evojax. I am also now working on a melee combat environment.

Here is the PyTorch implementation of the crowd and predator environments: https://github.com/kayuksel/multi-rl-crowd-sim

Here is a video where multi-agent predators are learning to surround prey to maximize hunts: https://youtu.be/Ds9O9wPyF8g

(I will also create a competitive multi-agent environment for closed-market auction where they will self-play by placing orders).

Have a nice week.

Sincerely,
Kamer

[Discussion] Sequencing side-effects in JAX

Sequencing side effects is a known issue in JAX. I have tried to make a custom operational env/task where the step update needs to be sequential.

The JAX docs say one needs to use tokens in the functions to force the sequence. However, I could not find any task in EvoJAX that uses tokens. Has anybody ever tried that? Or should I think more and redesign my problem to make it more JAX-compatible, since I use a bunch of jax.lax.cond calls (the constraints are quite complicated for the problem I have)?

Some proposals about the `Trainer` logic

Currently I see two ways of using the Trainer.test_task:

  1. The test_task of the trainer is used for validation. The actual test set is held out and not seen during training or validation. In this case, how do I run the actual test? I can't pass just the test_task to the trainer, because the train_task is non-optional. It looks like there should be a way to do this with evojax.
  2. The test_task of the trainer is used for the actual test, no validation is used at all. In this case, why does the trainer.run return the best model score and not the last model score?

I propose the following (high level) logic:

best_val_reward = trainer.fit(train_task: VectorizedTask, val_task: Optional[VectorizedTask] = None)  # maybe the user doesn't want validation (e.g. train on latest data without early stopping)
test_reward = trainer.test(test_task: VectorizedTask, checkpoint="best|last|path")  # specify which checkpoint to use for testing

Early stopping would probably be necessary for the trainer.fit method. Currently there is no way to determine when to stop, or even which model iteration has the best result.

I'm willing to implement this logic in a PR.
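
The early-stopping part of this proposal could be as small as the following sketch. `EarlyStopping` is a hypothetical helper written for this example, not an existing EvoJAX class; it records which iteration had the best validation reward and signals when no improvement has been seen for `patience` consecutive checks.

```python
class EarlyStopping:
    """Track the best validation reward and signal when progress has stalled."""

    def __init__(self, patience):
        self.patience = patience
        self.best_reward = float("-inf")
        self.best_iter = None
        self.stale = 0

    def update(self, iteration, val_reward):
        """Record a validation result; return True when training should stop."""
        if val_reward > self.best_reward:
            self.best_reward, self.best_iter, self.stale = val_reward, iteration, 0
        else:
            self.stale += 1
        return self.stale >= self.patience

stopper = EarlyStopping(patience=2)
val_rewards = [1.0, 3.0, 2.5, 2.0, 5.0]   # made-up validation rewards
stopped_at = None
for iteration, reward in enumerate(val_rewards, start=1):
    if stopper.update(iteration, reward):
        stopped_at = iteration
        break
```

With these made-up rewards the best checkpoint is the one from iteration 2, which is the one a `checkpoint="best"` option would load.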

AssertionError for OpenES

When I try to instantiate OpenES from open_es.py, I get the following error message:
[Screenshot 2022-12-15 at 20:23:59]
I traced the problem back to line 110 in open_es.py, where both the centered_rank and z_score arguments are set to True:
[Screenshot 2022-12-15 at 20:26:01]
But line 26 of the FitnessShaper class in evosax/utils/reshape_fitness.py says that
[Screenshot 2022-12-15 at 20:26:49]
How can I get around this issue?
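
For context, centered ranking and z-scoring are two alternative ways of normalising raw fitness values, which is presumably why they cannot both be enabled. The plain-Python sketch below shows what each transform computes; it is an illustration and not the evosax `FitnessShaper` implementation, and the `shape_fitness` helper with its assertion is hypothetical.

```python
import statistics

def centered_rank(fitness):
    """Map fitness values to evenly spaced ranks in [-0.5, 0.5]."""
    order = sorted(range(len(fitness)), key=lambda i: fitness[i])
    ranks = [0.0] * len(fitness)
    for rank, i in enumerate(order):
        ranks[i] = rank / (len(fitness) - 1) - 0.5
    return ranks

def z_score(fitness):
    """Standardise fitness values to zero mean and unit (population) variance."""
    mean = statistics.mean(fitness)
    std = statistics.pstdev(fitness)
    return [(f - mean) / (std + 1e-10) for f in fitness]

def shape_fitness(fitness, use_centered_rank=False, use_z_score=False):
    # The two transforms are alternatives, so at most one may be selected.
    assert not (use_centered_rank and use_z_score), "choose one transform"
    if use_centered_rank:
        return centered_rank(fitness)
    if use_z_score:
        return z_score(fitness)
    return list(fitness)

ranked = shape_fitness([10.0, 30.0, 20.0], use_centered_rank=True)
standardised = shape_fitness([10.0, 30.0, 20.0], use_z_score=True)
```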

AbstractPainting02.ipynb doesn't work on Colab

Hello, this is really great code.

I was able to run "Abstract Painting 01" very well on Google Colab.
However, when I ran "AbstractPainting02", an error occurred.

Exception                                 Traceback (most recent call last)
<ipython-input-20-b16203d22159> in <module>()
      2 devices = jax.local_devices()
      3 
----> 4 image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load('ViT-B/32', "cpu")
      5 
      6 target_text_ids = jnp.array(clip_jax.tokenize([prompt])) # already with batch dim

3 frames
/content/CLIP_JAX/clip_jax/clip.py in process_node(value, name)
    117             new_tensor = jnp.array(pytorch_tensor)
    118         else:
--> 119             raise Exception("not implemented")
    120 
    121         assert new_tensor.shape == value.shape

Exception: not implemented

Which version of clip_jax did you use when you made this notebook?

Best

Reinitialization

Hello,

I have a task with unknown global optima, and since optimizers can get stuck in local optima, I want to make sure the optimum found is reached from various random starting points. Therefore I would like to incorporate some kind of reinitialization of the whole search (basically running trainer.run with multiple different seeds).
Is this even necessary? Does SimManager -> eval_params -> _for_loop_eval -> policy_reset_func perform a reliable reinitialization of the policy state?

Thanks in advance for your advice.
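
Independent restarts of the kind described here can be sketched as a loop over seeds that keeps the best result. In the plain-Python example below, `toy_search` is a hypothetical stand-in for one full optimisation run (for example one call to trainer.run with a given seed); it is not part of EvoJAX, and the objective is made up.

```python
import math
import random

def toy_objective(x):
    """A 1-D function with several local optima (made up for this example)."""
    return math.sin(3.0 * x) - 0.1 * x * x

def toy_search(seed, n_steps=200, step=0.1):
    """Hypothetical single run: hill climbing from a seed-dependent start point."""
    rng = random.Random(seed)
    x = rng.uniform(-5.0, 5.0)
    best = toy_objective(x)
    for _ in range(n_steps):
        candidate = x + rng.gauss(0.0, step)
        value = toy_objective(candidate)
        if value > best:
            x, best = candidate, value
    return best, x

def run_with_restarts(seeds):
    """Run the search once per seed and return (best_score, best_x, seed)."""
    results = [toy_search(seed) + (seed,) for seed in seeds]
    return max(results, key=lambda r: r[0])

best_score, best_x, best_seed = run_with_restarts(range(8))
```

Comparing the per-seed results also shows how sensitive the outcome is to the starting point.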

Can one specify parts of the model that are non-differentiable?

I have a model in JAX (a convolutional neural network with some modifications) where most of it is fully differentiable, but parts are not. Can I somehow mark which parts are not differentiable so as to get correct gradient backpropagation, or is it done automatically?

Thanks!

Evaluating brax environments other than brax-ant. Terminates with error.

Information

The issue is with running brax environments other than brax-ant. The included humanoid, half cheetah and fetch environments are affected.

Couldn't find any references to this issue in the repo. I could have missed something.

Expected Behavior

/home/<USER>/anaconda3/envs/evojax/bin/python /home/<USER>/evojax/scripts/benchmarks/train.py -config configs/PGPE/brax_halfcheetah.yaml
brax: 2022-06-16 20:41:01,954 [INFO] EvoJAX brax
brax: 2022-06-16 20:41:01,954 [INFO] ==============================
absl: 2022-06-16 20:41:02,137 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-06-16 20:41:02,221 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
MLPPolicy: 2022-06-16 20:41:03,747 [INFO] MLPPolicy.num_params = 3974
brax: 2022-06-16 20:41:03,787 [INFO] use_for_loop=False
brax: 2022-06-16 20:41:03,825 [INFO] Start to train for 1 iterations.
brax: 2022-06-16 20:41:56,024 [INFO] [TEST] Iter=1, #tests=1, max=-9.7476, avg=-9.7476, min=-9.7476, std=0.0000
brax: 2022-06-16 20:41:56,087 [INFO] Training done, best_score=-9.7476
brax: 2022-06-16 20:41:56,093 [INFO] Loaded model parameters from ./log/PGPE/brax/default.
brax: 2022-06-16 20:41:56,093 [INFO] Start to test the parameters.
brax: 2022-06-16 20:42:03,478 [INFO] [TEST] #tests=1, max=-9.9009, avg=-9.9009, min=-9.9009, std=0.0000

Current Behavior

brax: 2022-06-16 20:26:04,657 [INFO] EvoJAX brax
brax: 2022-06-16 20:26:04,657 [INFO] ==============================
absl: 2022-06-16 20:26:04,833 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-06-16 20:26:04,920 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
MLPPolicy: 2022-06-16 20:26:06,465 [INFO] MLPPolicy.num_params = 3974
brax: 2022-06-16 20:26:06,504 [INFO] use_for_loop=False
brax: 2022-06-16 20:26:06,541 [INFO] Start to train for 10 iterations.
Traceback (most recent call last):
  File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 88, in <module>
    main(config)
  File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 64, in main
    trainer.run(demo_mode=False)
  File "/home/<USER>/evojax/evojax/trainer.py", line 152, in run
    scores, bds = self.sim_mgr.eval_params(
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 258, in eval_params
    return self._scan_loop_eval(params, test)
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 355, in _scan_loop_eval
    scores, all_obs, masks, final_states = rollout_func(
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 202, in rollout
    (obs_set, obs_mask)) = jax.lax.scan(
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1630, in scan
    _check_tree_and_avals("scan carry output and input",
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2316, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: scan carry output and input must have identical types, got
(State(state=State(qp=QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), obs='ShapedArray(float32[16384,18])', reward='ShapedArray(float32[16384])', done='ShapedArray(float32[16384])', metrics={'reward_ctrl_cost': 'ShapedArray(float32[16384])', 'reward_forward': 'ShapedArray(float32[16384])'}, info={'first_obs': 'ShapedArray(float32[16384,18])', 'first_qp': QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), 'steps': 'ShapedArray(float32[16384])', 'truncation': 'ShapedArray(float32[16384])'}), obs='ShapedArray(float32[16384,18])', feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])'), PolicyState(keys='ShapedArray(uint32[16384,2])'), 'ShapedArray(float32[16384,3974])', 'ShapedArray(float32[37])', 'ShapedArray(float32[16384])', 'ShapedArray(float32[16384])').

Exact Error:

feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])')

Failure Information

Context

Based on the commit history, this appears to be due to the changes introduced in #33.
Manually altering the feet_contact variable in the reset_fn method in evojax/evojax/task/brax_task.py allows the other environments to run.

Setup details related to the hardware are irrelevant, since the error also occurs on the hosted Colab notebook.

brax                         0.0.13
evojax                       0.2.11               
flax                         0.4.0
jax                          0.3.1
jaxlib                       0.3.0+cuda11.cudnn82

Steps to Reproduce

Please provide detailed steps for reproducing the issue.

  1. Run evojax/scripts/benchmarks/train.py using a modified evojax/scripts/benchmarks/configs/<ES> file that uses a non-ant brax environment.
  2. Modify feet_contact array size and test.

Minor issue with GIF at the end of the Abstract Paintings notebook 1

The notebook (https://github.com/google/evojax/blob/main/examples/notebooks/AbstractPainting01.ipynb) does the following to turn the saved frames into a GIF:

import glob
import IPython

frames = []
imgs = glob.glob("AbstractPainting01_canvas_record.*.png")
for file in imgs:
  new_frame = Image.open(file)
  frames.append(new_frame)
frames[0].save('AbstractPainting01_final.gif', save_all=True, append_images=frames, optimize=True, duration=200, loop=0)

The resulting GIF doesn't show the frames in order, because glob returns results in an arbitrary order (probably whatever order files appear in the filesystem).

Fix:
imgs = sorted(glob.glob("AbstractPainting01_canvas_record.*.png"))

I'd submit a pull request but I don't have 8 A100s lying around so it'd take a while to re-run the example :)
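
Note that `sorted` orders file names lexicographically, so it only yields frame order if the frame numbers are zero-padded. If they are not (an assumption about the file names, which the issue does not show), sorting by the numeric part is safer. A small sketch with a hypothetical `frame_index` helper:

```python
import re

def frame_index(path):
    """Extract the integer between 'canvas_record.' and '.png' (assumed naming)."""
    match = re.search(r"canvas_record\.(\d+)\.png$", path)
    return int(match.group(1)) if match else -1

# Made-up file names following the glob pattern used in the notebook.
names = [
    "AbstractPainting01_canvas_record.10.png",
    "AbstractPainting01_canvas_record.2.png",
    "AbstractPainting01_canvas_record.1.png",
]
lexicographic = sorted(names)             # frame 10 sorts before frame 2
numeric = sorted(names, key=frame_index)  # frames in numeric order
```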

Can't execute Brax notebook

Hi all, I ran the notebook BraxTasks.ipynb as is, and the second cell crashes with the following error. I think it may be an issue with the Brax version.
[Screenshot 2023-09-24 at 18:11:08]

Issue with BatchNorm layer

When using a policy network with a BatchNorm layer, I get the following error:

ModifyScopeVariableError: Cannot update variable "mean" in "/bn_init" because collection "batch_stats" is immutable. 
 (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)

How to use GPU for computing

Dear developer,
Thanks for your awesome work. I have some questions.
When I run the seq2seq example, I get this warning:

Seq2seq: 2023-09-26 20:11:32,443 [INFO] ==============================
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695730292.463880 3694542 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-26 20:11:32.491777: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-09-26 20:11:32.491820: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: user-MD72-HB3-00
2023-09-26 20:11:32.491829: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: user-MD72-HB3-00
2023-09-26 20:11:32.491882: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 535.86.5
2023-09-26 20:11:32.491912: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 535.86.5
2023-09-26 20:11:32.491920: I external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 535.86.5
jax._src.xla_bridge: 2023-09-26 20:11:32,492 [INFO] Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
jax._src.xla_bridge: 2023-09-26 20:11:32,492 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
jax._src.xla_bridge: 2023-09-26 20:11:32,494 [INFO] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
jax._src.xla_bridge: 2023-09-26 20:11:32,494 [WARNING] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I have already successfully installed JAX with GPU support; my machine runs Linux with a 4090 GPU and CUDA 12.2. How do I fix this problem?

A GNN-based Meta-Learning Method for Sparse Portfolio Optimization

Hello,

Let me start by saying that I am a fan of your work here. I have recently open-sourced my GNN-based meta-learning method for optimization. I have applied it to a real-world sparse index-tracking problem (after an initial benchmarking on the Schwefel function), and it seems to outperform Fast CMA-ES significantly, both in terms of producing robust solutions on the blind test set and in terms of time (total duration and iterations) and space complexity. I include the link to my repository here, in case you would consider adding the method or the benchmarking problem to your repository. Note: the GNN, which learns how to generate populations of solutions at each iteration, is trained using gradients retrieved from the loss function, as opposed to black-box ones.

Sincerely, K

Bug with center_lr_decay_steps when using Adam with PGPE

Bug

When using Adam with PGPE, this code

self._opt_state = self._opt_update(
            self._t // self._lr_decay_steps, -grad_center, self._opt_state
        )

means that Adam's step index only increases once every self._lr_decay_steps iterations.
As a result, mhat and vhat do not behave as bias-corrected moving averages, because (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) stays very small. (The Adam update code is shown below.)

def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First  moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v

Suggestion

I think it is better to change this code to

step_size=lambda x: self._center_lr * jnp.power(decay_coef, x // self._lr_decay_steps),

and to remove self._lr_decay_steps at

self._opt_state = self._opt_update(
            self._t, -grad_center, self._opt_state
        )
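
To see the effect numerically, compare the bias-correction denominator 1 - b1 ** (i + 1) when `i` is the decayed index `t // lr_decay_steps` versus the raw step `t`. The values `b1 = 0.9` and `lr_decay_steps = 1000` below are illustrative assumptions, not necessarily the configuration used in EvoJAX.

```python
def bias_correction(b1, i):
    """Denominator used by Adam to de-bias the first-moment estimate."""
    return 1.0 - b1 ** (i + 1)

b1 = 0.9                # illustrative value
lr_decay_steps = 1000   # illustrative value

for t in (0, 10, 100, 999):
    decayed = bias_correction(b1, t // lr_decay_steps)  # index stays at 0 here
    raw = bias_correction(b1, t)                        # index grows with t
    print(f"t={t:4d}  decayed-index: {decayed:.4f}  raw-index: {raw:.4f}")
```

With the decayed index, the denominator remains about 0.1 for the first 1000 steps, so m keeps being divided by roughly 0.1 long after the moving average has warmed up; with the raw step it approaches 1 as intended.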

Reproducing benchmark scores

Hello everyone.

I am currently trying to reproduce scores from the benchmarks, specifically for ARS, as I am implementing my own version natively in JAX and want to compare it with the wrapper already implemented.

For example, I cannot achieve the score posted in the benchmark table (902.107) for ARS on cartpole_easy.

Running python train.py -config configs/ARS/cartpole_easy.yaml yields the following training logs:

cartpole_easy: 2022-09-25 22:45:55,777 [INFO] EvoJAX cartpole_easy
cartpole_easy: 2022-09-25 22:45:55,777 [INFO] ==============================
absl: 2022-09-25 22:45:55,791 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
MLPPolicy: 2022-09-25 22:45:59,165 [INFO] MLPPolicy.num_params = 4609
cartpole_easy: 2022-09-25 22:45:59,429 [INFO] use_for_loop=False
cartpole_easy: 2022-09-25 22:45:59,496 [INFO] Start to train for 1000 iterations.
cartpole_easy: 2022-09-25 22:46:10,527 [INFO] Iter=50, size=100, max=399.5886, avg=207.9111, min=0.5843, std=99.0207
cartpole_easy: 2022-09-25 22:46:19,916 [INFO] Iter=100, size=100, max=543.8907, avg=364.9780, min=28.8478, std=141.8982
cartpole_easy: 2022-09-25 22:46:21,143 [INFO] [TEST] Iter=100, #tests=100, max=553.4018 avg=510.5583, min=462.4243, std=15.6930
cartpole_easy: 2022-09-25 22:46:30,627 [INFO] Iter=150, size=100, max=558.2020, avg=314.9279, min=89.8001, std=153.6488
cartpole_easy: 2022-09-25 22:46:40,068 [INFO] Iter=200, size=100, max=562.4118, avg=354.9529, min=47.0048, std=154.1567
cartpole_easy: 2022-09-25 22:46:40,114 [INFO] [TEST] Iter=200, #tests=100, max=570.1135 avg=547.5375, min=508.5795, std=10.0840
cartpole_easy: 2022-09-25 22:46:49,579 [INFO] Iter=250, size=100, max=562.1505, avg=325.3990, min=73.3733, std=161.9460
cartpole_easy: 2022-09-25 22:46:59,073 [INFO] Iter=300, size=100, max=569.5461, avg=370.2641, min=83.7473, std=166.8020
cartpole_easy: 2022-09-25 22:46:59,129 [INFO] [TEST] Iter=300, #tests=100, max=573.5941 avg=545.0388, min=505.8637, std=11.3853
cartpole_easy: 2022-09-25 22:47:08,623 [INFO] Iter=350, size=100, max=579.3894, avg=425.6462, min=82.4907, std=126.6614
cartpole_easy: 2022-09-25 22:47:18,109 [INFO] Iter=400, size=100, max=627.6509, avg=530.2781, min=156.4797, std=76.0956
cartpole_easy: 2022-09-25 22:47:18,160 [INFO] [TEST] Iter=400, #tests=100, max=639.7323 avg=600.9105, min=573.7767, std=10.7564
cartpole_easy: 2022-09-25 22:47:27,653 [INFO] Iter=450, size=100, max=668.2064, avg=546.0261, min=418.5385, std=60.5854
cartpole_easy: 2022-09-25 22:47:37,149 [INFO] Iter=500, size=100, max=684.4142, avg=574.4891, min=446.3126, std=62.5338
cartpole_easy: 2022-09-25 22:47:37,202 [INFO] [TEST] Iter=500, #tests=100, max=693.1522 avg=682.7945, min=638.0387, std=12.1575
cartpole_easy: 2022-09-25 22:47:46,708 [INFO] Iter=550, size=100, max=708.9561, avg=591.0547, min=295.5651, std=73.6026
cartpole_easy: 2022-09-25 22:47:56,212 [INFO] Iter=600, size=100, max=706.8138, avg=599.4783, min=348.7581, std=55.6310
cartpole_easy: 2022-09-25 22:47:56,263 [INFO] [TEST] Iter=600, #tests=100, max=691.0123 avg=680.4677, min=630.2983, std=6.1448
cartpole_easy: 2022-09-25 22:48:05,770 [INFO] Iter=650, size=100, max=707.0887, avg=581.3851, min=418.2251, std=75.9066
cartpole_easy: 2022-09-25 22:48:15,275 [INFO] Iter=700, size=100, max=712.7586, avg=586.4597, min=362.7628, std=71.5669
cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
cartpole_easy: 2022-09-25 22:48:24,849 [INFO] Iter=750, size=100, max=716.1056, avg=602.7747, min=458.0401, std=63.1697
cartpole_easy: 2022-09-25 22:48:34,365 [INFO] Iter=800, size=100, max=709.3475, avg=587.9896, min=393.0367, std=69.2385
cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
cartpole_easy: 2022-09-25 22:48:43,945 [INFO] Iter=850, size=100, max=706.8488, avg=598.3582, min=321.8640, std=75.2542
cartpole_easy: 2022-09-25 22:48:53,482 [INFO] Iter=900, size=100, max=720.0320, avg=596.1929, min=370.6555, std=77.2801
cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
cartpole_easy: 2022-09-25 22:49:03,068 [INFO] Iter=950, size=100, max=716.2341, avg=598.3802, min=422.7760, std=71.7756
cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952
cartpole_easy: 2022-09-25 22:49:12,458 [INFO] Loaded model parameters from ./log/ARS/cartpole_easy/default.
cartpole_easy: 2022-09-25 22:49:12,459 [INFO] Start to test the parameters.
cartpole_easy: 2022-09-25 22:49:12,509 [INFO] [TEST] #tests=100, max=728.9848, avg=720.6152, min=698.9832, std=5.0566

I am not entirely sure whether the result in the benchmark table is meant to be the 720.5952 from
cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952

or the max score from the final test. Either way, neither of these matches the score posted in the benchmark table.

Am I doing something wrong when trying to reproduce these scores?
Without matching numbers, I cannot compare my own implementation of the algorithm against the benchmark.

Thank you
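For reference, the `best_score=720.5952` reported at the end of training coincides with the `avg` of the periodic `[TEST]` at `Iter=800` in the log above. That suggests (as an inference from this log only, not from any documentation) that `best_score` tracks the highest periodic test average rather than the final test. A minimal sketch for checking this against a log, where the `best_test_average` helper and the regular expression are illustrative assumptions and only a few `[TEST]` lines are copied in:

```python
import re

# A few [TEST] lines copied verbatim from the log above.
LOG = """\
cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
"""

# Matches periodic test lines only (they carry an Iter= field) and
# captures the iteration number and the average score.
TEST_LINE = re.compile(r"\[TEST\] Iter=(\d+),.*?avg=([0-9.]+)")


def best_test_average(log_text):
    """Return (iteration, avg) of the periodic test with the highest avg."""
    results = [
        (int(m.group(1)), float(m.group(2)))
        for m in TEST_LINE.finditer(log_text)
    ]
    return max(results, key=lambda item: item[1])


best_iter, best_avg = best_test_average(LOG)
print(best_iter, best_avg)  # 800 720.5952
```

Running the same helper over the full log file would show whether any earlier test exceeded this value.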

Evolving topology of NN

Hey guys,

Amazing work! Quick question: is it possible to also evolve the topology of networks using your framework, as NEAT does?

High-dimensional parametric search

I'm trying to use EvoJAX to evolve my model parameters. I found that the algorithm only accepts the parameter num_dims as the dimension. Can it only be an int? If I want to evolve multidimensional parameters, such as a [1000x1000] array, how can I do it? Thanks!
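One common workaround, sketched here under the assumption that the solver only ever sees flat vectors of length `num_dims`, is to search over the flattened parameters and reshape them back just before they are used. The names `PARAM_SHAPE`, `NUM_DIMS`, `unflatten` and `unflatten_population` are hypothetical, and plain NumPy is used for illustration; `jax.numpy` offers the same `reshape` call.

```python
import numpy as np

# Shape the model actually wants (taken from the question above).
PARAM_SHAPE = (1000, 1000)

# The single integer that would be handed to the solver as num_dims.
NUM_DIMS = int(np.prod(PARAM_SHAPE))


def unflatten(flat_params):
    """Reshape one flat parameter vector back into PARAM_SHAPE."""
    return np.asarray(flat_params).reshape(PARAM_SHAPE)


def unflatten_population(population):
    """Reshape a (pop_size, NUM_DIMS) batch into (pop_size, *PARAM_SHAPE)."""
    return population.reshape((population.shape[0],) + PARAM_SHAPE)


# Stand-in for one candidate solution proposed by a solver.
flat = np.zeros(NUM_DIMS, dtype=np.float32)
matrix = unflatten(flat)
print(NUM_DIMS, matrix.shape)  # 1000000 (1000, 1000)
```

Note that a million-dimensional search space is large for many evolution strategies, so memory use and convergence should be checked for the specific algorithm chosen.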

OpenAI Gym Integration

Hi!

Is there any example of how to use EvoJAX with a Gym-like environment that is not implemented in JAX? Is that even possible?

Thank you!
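As a generic illustration of the idea (not an EvoJAX example), an ask/tell-style solver can be driven from a host-side Python loop that evaluates each candidate on an ordinary, non-JAX environment. Everything below is a hypothetical stand-in: `ToyEnv` mimics a classic Gym-style `reset`/`step` interface, `ToySolver` is a minimal Gaussian search, and neither is taken from EvoJAX or Gym. Evaluating on the host also gives up the accelerator-side speedups described in the introduction.

```python
import numpy as np


class ToyEnv:
    """Hypothetical stand-in for a Gym-style environment (reset/step)."""

    def __init__(self, horizon=20):
        self.horizon = horizon

    def reset(self):
        self.t = 0
        self.state = np.array([1.0, -1.0])
        return self.state

    def step(self, action):
        # The reward is higher when the action drives the state toward zero.
        self.state = self.state + 0.1 * action
        self.t += 1
        reward = -float(np.sum(self.state ** 2))
        done = self.t >= self.horizon
        return self.state, reward, done, {}


def rollout(env, params):
    """Run one episode with a linear policy: action = W @ observation."""
    weights = params.reshape(2, 2)
    obs, total, done = env.reset(), 0.0, False
    while not done:
        obs, reward, done, _ = env.step(weights @ obs)
        total += reward
    return total


class ToySolver:
    """Hypothetical ask/tell solver: Gaussian samples around a mean."""

    def __init__(self, num_dims, pop_size, sigma=0.1, seed=0):
        self.rng = np.random.default_rng(seed)
        self.mean = np.zeros(num_dims)
        self.pop_size = pop_size
        self.sigma = sigma

    def ask(self):
        noise = self.rng.standard_normal((self.pop_size, self.mean.size))
        self.population = self.mean + self.sigma * noise
        return self.population

    def tell(self, fitness):
        # Move the mean to the average of the top quarter of candidates.
        num_elite = max(1, self.pop_size // 4)
        elite = np.argsort(fitness)[-num_elite:]
        self.mean = self.population[elite].mean(axis=0)


env = ToyEnv()
solver = ToySolver(num_dims=4, pop_size=16)
for _ in range(30):
    population = solver.ask()
    # Host-side evaluation: one ordinary Python rollout per candidate.
    fitness = np.array([rollout(env, p) for p in population])
    solver.tell(fitness)
final_score = rollout(env, solver.mean)
print(final_score)
```

The loop evaluates candidates one at a time on the CPU, so it trades throughput for compatibility with environments that cannot be compiled.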
