Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

Overview

coax

For the full documentation, including many examples, go to https://coax.readthedocs.io/

Install

coax is built on top of JAX, but it doesn't have an explicit dependency on the jax Python package, because the right version of jaxlib depends on your CUDA version. To install without CUDA support, simply run:

$ pip install jaxlib jax coax --upgrade

If you do require CUDA support, please check out the Installation Guide.

Getting Started

Have a look at the Getting Started page to train your first RL agent.


Comments
  • Quantile Q-Learning Implementation

    This PR adds a QuantileQ class with function types 3 and 4, which accept a number of quantiles together with the state (and action), as well as a QuantileQLearning class. The QuantileQ function could be merged into the Q class, which would simplify the user-facing API. However, more work is needed to incorporate the QuantileQLearning class into the QLearning class. I just wanted to confirm that this is the right approach for implementing IQN.

    Documentation for the quantile Huber loss is still missing, and the notebooks still need to be added and tuned.
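    For reference, here is a minimal plain-JAX sketch of the standard quantile Huber loss used by QR-DQN and IQN. It is not coax's implementation; the function names, the argument layout (one set of predicted quantiles at fractions tau against one set of target samples) and the default kappa=1.0 are assumptions for illustration.

```python
import jax.numpy as jnp


def huber(u, kappa=1.0):
    """Elementwise Huber loss with threshold kappa."""
    abs_u = jnp.abs(u)
    return jnp.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))


def quantile_huber_loss(pred, target, tau, kappa=1.0):
    """Quantile Huber loss for a single state-action pair.

    pred:   (N,) predicted quantile values at fractions tau
    target: (M,) target quantile samples (treated as constants)
    tau:    (N,) quantile fractions in (0, 1)
    """
    # pairwise TD errors u_ij = target_j - pred_i, shape (N, M)
    u = target[None, :] - pred[:, None]
    # asymmetric quantile weight |tau_i - 1{u_ij < 0}|
    weight = jnp.abs(tau[:, None] - (u < 0).astype(pred.dtype))
    rho = weight * huber(u, kappa) / kappa
    # average over target samples, sum over predicted quantiles
    return rho.mean(axis=1).sum()
```

    In a training loop the target would typically be wrapped in jax.lax.stop_gradient and the per-sample loss batched with jax.vmap.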

    Closes https://github.com/coax-dev/coax/issues/3

    opened by frederikschubert 11
  • Add DeepMind Control Suite Example

    This PR is a rework of https://github.com/coax-dev/coax/pull/26 and adds an example for using SAC on the Walker.walk task from the DeepMind Control Suite.

    Depends on https://github.com/coax-dev/coax/pull/27 and https://github.com/coax-dev/coax/pull/28

    opened by frederikschubert 6
  • Assertion assert_equal_shape failed for MultiDiscrete action space

    First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently running a baseline study with my customized gym environment and VanillaPG, but I run into the error shown below and could not figure it out. My understanding is that it is complaining that log_pi should not have shape (4,). But I do have a MultiDiscrete action space, so its corresponding log_pi should be something like (4,) or (1, 4). I have also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

    So my questions are:

    1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
    2. Did I declare my policy function properly?
    3. Any general ideas on how to debug JAX functions?

    I would appreciate any feedback. Thank you!

    Error message

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Input In [25], in <cell line: 5>()
         13     transition_batch = tracer.pop()
         14     Gn = transition_batch.Rn
    ---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
         16     env.record_metrics(metrics)
         17 if done:
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
        127 def update(self, transition_batch, Adv):
        128     r"""
        129 
        130     Update the model parameters (weights) of the underlying function approximator.
       (...)
        147 
        148     """
    --> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
        150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
        151         raise RuntimeError(f"found nan's in grads: {grads}")
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
        212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
        213     warnings.warn(
        214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
        215         "should be non-zero. Please sample actions with their propensities: "
        216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
        217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
    --> 218 return self._grad_and_metrics_func(
        219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
        220     transition_batch, Adv)
    
    File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
         58 def __call__(self, *args, **kwargs):
    ---> 59     return self._jitted_func(*args, **kwargs)
    
        [... skipping hidden 14 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
         77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
         78     grads_func = jax.grad(loss_func, has_aux=True)
         79     grads, (metrics, state_new) = \
    ---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
         82     # add some diagnostics of the gradients
         83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))
    
        [... skipping hidden 10 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
         45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
         46     objective, (dist_params, log_pi, state_new) = \
    ---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
         49     # flip sign to turn objective into loss
         50     loss = -objective
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
         49 W = jnp.clip(transition_batch.W, 0.1, 10.)
         51 # some consistency checks
    ---> 52 chex.assert_equal_shape([W, Adv, log_pi])
         53 chex.assert_rank([W, Adv, log_pi], 1)
         54 objective = W * Adv * log_pi
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
        195 else:
        196   try:
    --> 197     host_assertion(*args, **kwargs)
        198   except jax.errors.ConcretizationTypeError as exc:
        199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
        200            "likely that it tried to access tensors' values during tracing. "
        201            "Make sure that you defined a jittable version of this Chex "
        202            "assertion.")
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
        154     custom_message = custom_message.format(*custom_message_format_vars)
        155   error_msg = f"{error_msg} [{custom_message}]"
    --> 157 raise exception_type(error_msg)
    
    AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
    

    Example data

    ExampleData(
      inputs=Inputs(
        args=ArgsType2(
          S={
            'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)
          is_training=True)
        static_argnums=(
          1))
      output=(
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
    

    Policy function

    def pi(S, is_training):
        module = CustomizedModule()
        res = tuple([{"logits": item} for item in module(S["features"])])
        return res
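    For context: the assertion above requires log_pi to have the same rank-1 shape as W and Adv, i.e. one value per sample, which is (1,) for a batch of size 1. When the components of a MultiDiscrete action are independent categoricals, that per-sample value is the sum of the per-component log-probabilities. The plain-JAX sketch below shows this reduction; it is not coax's internal code, and the function name and argument layout are assumptions.

```python
import jax
import jax.numpy as jnp


def multidiscrete_log_prob(logits_tuple, actions):
    """Joint log-probability of MultiDiscrete actions.

    logits_tuple: tuple of K arrays, each of shape (batch, n_k)
    actions:      int array of shape (batch, K)
    returns:      array of shape (batch,)
    """
    per_component = []
    for k, logits in enumerate(logits_tuple):
        log_p = jax.nn.log_softmax(logits, axis=-1)   # (batch, n_k)
        a_k = actions[:, k]                           # (batch,)
        per_component.append(jnp.take_along_axis(log_p, a_k[:, None], axis=-1)[:, 0])
    # independent components: multiply probabilities, i.e. sum log-probabilities
    return jnp.sum(jnp.stack(per_component, axis=-1), axis=-1)
```

    With logits shaped like the example data above (four components of 10 logits, batch size 1), this returns an array of shape (1,) rather than (4,).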
    
    question 
    opened by xiangyuy 5
  • 'linear/w' does not match shape

    I've started learning about RL and have been trying to get coax up and running, but I've run into an issue that I'm not sure how to resolve. I'm doing Q-learning on a custom gym environment, and I can run the following pieces successfully:

    q = coax.Q(func_q, env)
    pi = coax.Policy(func_pi, env)
    
    qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
    cache = coax.reward_tracing.NStep(n=1, gamma=0.9)
    

    Additionally, my setup passes the simple checks of:

    data = coax.Q.example_data(env) # Looks good
    ...
    s = env.observation_space.sample()
    a = env.action_space.sample()
    print(q(s,a)) # 0.0
    ...
    a = pi(s)
    print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space
    

    However, once I get to actually running the training loop:

    for ep in range(50):
      pi.epsilon = 0.1
      s = env.reset()
    
      for t in range(env.maxGuesses):
        a = pi(s)
        s_next, r, done, info = env.step(a)
    
        # update
        cache.add(s, a, r, done)
    
        while cache:
          transition_batch = cache.pop()
          metrics = qlearning.update(transition_batch)
          env.record_metrics(metrics)
    
        if done:
          break
    
        s = s_next
    
        # early stopping
        if env.avg_G > env.reward_threshold:
          break
    

    I get a bunch of errors with the most human-readable of them saying:

    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    By adjusting the parameters of the environment, I can change which numbers are mismatched. I can't get them to match, and in any case that seems like the wrong fix; something more fundamental seems to be wrong.

    For reference, here are my functions for q and pi:

    def func_pi(S, is_training):
      logits = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(Wordle.wordLength*len(alphabet), w_init=jnp.zeros) # This many possible actions
      ))
      # First, convert to a vector:
      sVec = state_to_vec(S)
    
      # Now get the output:
      logitVec = logits(sVec)
    
      # Now chunk the output into alphabet-sized pieces (definitionally an integral
      # number of them). There will be Wordle.wordLength chunks of this length
      chunks = jnp.split(logitVec, Wordle.wordLength)
    
      # Now format our output array:
      ret = []
      for chunk in chunks:
        ret.append({'logits': jnp.reshape(chunk,(1,len(alphabet)))})
    
      return tuple(ret)
    
    # and for actual state:
    def func_q(S, A, is_training):
      value = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
      ))
    
      sVec = state_to_vec(S)
      aVec = action_to_vec(A)
    
      X = jnp.concatenate((sVec, aVec))
      return value(X)
    

    Note that state_to_vec(S) and action_to_vec(A) just convert my internal types to jnp arrays for use with Haiku.

    I'm quite new to coax/JAX/Haiku so it's entirely possible I've set something up wrong. For completeness here's the full text of the error:

    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
        ans = call(fun, *args)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
        (_, aux), g = value_and_grad_f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
        ans, vjp_py, aux = _vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
        out_primal, out_vjp, aux = ad.vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
        out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
        jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    Please let me know if other information would be useful or relevant (or let me know if this isn't actually a coax issue...).

    Thanks for your help and the neat package.

    bug good first issue question 
    opened by bcerjan 5
  • DQN pong example doesn't work off the shelf

    Describe the bug

    Running the DQN example on Pong produces the following error when generating a GIF:

      File ".../lib/python3.9/site-packages/coax/utils/_misc.py", line 475, in generate_gif
        assert env.render_mode == 'rgb_array', "env.render_mode must be 'rgb_array'"
    

    This is likely due to some recent updates to gym. Currently, on gym==0.26.2 I observe the following:

    import gym
    env = gym.make('PongNoFrameskip-v4', render_mode="rgb_array")
    print(env.render_mode) # prints None
    
    opened by thisiscam 4
  • Add dm_control example for SAC

    This PR introduces the common squashed normal distribution for the SAC policy on dm_control and provides an example that solves the walker.walk task. Interestingly, clipping the actions to the range [-1, 1] instead makes training diverge.

    @KristianHolsheimer How would you go about changing the installation script for this notebook to add dm_control as a dependency?
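    For readers unfamiliar with it, a squashed normal policy samples u from a diagonal Gaussian and outputs a = tanh(u), so actions always lie in (-1, 1); the log-density then needs a change-of-variables correction of log(1 - tanh(u)^2) per action dimension. Below is a minimal plain-JAX sketch of that technique, not the distribution class added in this PR; the function names and the (mu, log_std) parameterization are assumptions.

```python
import jax
import jax.numpy as jnp


def tanh_log_det_jacobian(u):
    # log(1 - tanh(u)^2), rewritten as 2 * (log 2 - u - softplus(-2u)) for numerical stability
    return 2.0 * (jnp.log(2.0) - u - jax.nn.softplus(-2.0 * u))


def sample_squashed_normal(key, mu, log_std):
    """Sample a = tanh(u), u ~ N(mu, exp(log_std)^2); return (a, log p(a)) per sample."""
    std = jnp.exp(log_std)
    u = mu + std * jax.random.normal(key, mu.shape)
    # diagonal Gaussian log-density of u, summed over the action dimensions
    log_prob_u = jnp.sum(
        -0.5 * ((u - mu) / std) ** 2 - log_std - 0.5 * jnp.log(2.0 * jnp.pi), axis=-1)
    # change of variables for a = tanh(u)
    log_prob_a = log_prob_u - jnp.sum(tanh_log_det_jacobian(u), axis=-1)
    return jnp.tanh(u), log_prob_a
```

    Because the squashing is part of the distribution, sampled actions never need clipping, and the log-probability stays consistent with the action that is actually executed.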

    opened by frederikschubert 4
  • Frozen Lake example has an invalid gym signature.

    Describe the bug

    The Frozen Lake example in the main branch of the docs hasn't been fully updated for the new gym step() signature.

    ValueError                                Traceback (most recent call last)
         77 
         78 a = pi.mode(s)
    ---> 79 s, r, done, info = env.step(a)
         80 
         81 env.render()

    ValueError: too many values to unpack (expected 4)

    Expected behavior

    Executing the notebook should not result in a ValueError.

    To Reproduce

    Colab notebook to repro the bug:

    - https://colab.research.google.com/...

    Runtime used for this colab notebook: ... (e.g. CPU/GPU/TPU)

    Any.

    Additional context

    Simple fix, happy to contribute a pull request.
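    For context, gym 0.26 changed env.step() to return five values, (obs, reward, terminated, truncated, info), instead of the old four, which is what triggers the unpacking error above. A minimal sketch of a shim that accepts either signature is below; step_compat is a hypothetical helper, not part of gym or coax, and updating the notebook to the new API directly is the cleaner fix.

```python
def step_compat(env, action):
    """Call env.step and return (obs, reward, done, info) for both old and new gym APIs."""
    result = env.step(action)
    if len(result) == 5:
        # gym>=0.26: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = result
        return obs, reward, terminated or truncated, info
    # older gym: (obs, reward, done, info)
    obs, reward, done, info = result
    return obs, reward, done, info
```

    Merging terminated and truncated into a single done flag discards information; for TD targets it is usually preferable to keep bootstrapping on truncation.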

    opened by dbleyl 3
  • Incorporating jax.jit into a custom policy

    I'm a bit new to JAX, so my question might sound very naive. Suppose we are trying to solve a policy optimization problem with the REINFORCE algorithm, and suppose we already have our environment at hand (env). We define our custom policy as follows:

    class CustomPolicy(hk.Module):
        def __init__(self, name = None):
            super().__init__(name = name)
        
    
        def __call__(self, x):
            w = hk.get_parameter("w", shape= ... , dtype = x.dtype, init=jnp.zeros)
            # some computation
            return out
    

    Per the documentation, then we define

    def custom_policy(S, is_training=True):
        logits = CustomPolicy()
        return {'logits': logits(S)}
    

    and finally the policy is stated as follows,

    pi = coax.Policy(custom_policy, env)

    I was wondering whether there is any way to incorporate @jax.jit into this structure to further improve performance. Thanks.

    question 
    opened by UweGensheimer 3
  • Multi-Step Entropy Regularization for SAC

    • Add record_extra_info flag to the NStep tracer that records the intermediate states in the new extra_info field to TransitionBatch
    • Add support for the NStepEntropyRegularizer in SoftPG

    This PR contains an initial working implementation of the mechanism, which sums up the discounted entropy bonuses of the states s_t, s_{t + 1}, ..., s_{t + n - 1} for the soft policy gradient regularization.
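    The quantity described above, the discounted sum of entropy bonuses over the n states s_t, ..., s_{t+n-1}, can be sketched for a discrete policy as follows. This is a plain-JAX illustration with assumed names, not the NStepEntropyRegularizer code; how the PR scales the bonus (for example with a temperature coefficient) is not shown here.

```python
import jax
import jax.numpy as jnp


def categorical_entropy(logits):
    """Entropy H(pi(.|s)) of a categorical policy, computed from logits of shape (..., num_actions)."""
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)


def nstep_entropy_bonus(logits_seq, gamma):
    """sum_{k=0}^{n-1} gamma^k H(pi(.|s_{t+k})), given logits_seq of shape (n, num_actions)."""
    entropies = categorical_entropy(logits_seq)          # (n,)
    discounts = gamma ** jnp.arange(entropies.shape[0])   # 1, gamma, gamma^2, ...
    return jnp.sum(discounts * entropies)
```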

    opened by frederikschubert 3
  • Implementation of SAC

    Since SAC is really similar to TD3, we are able to re-use most of its components. The differences are:

    • The actions to update the q-functions and policy are sampled using the current policy (instead of taking the mode).
    • There is no target policy.
    • The log variance of the policy depends on the state.
    • The policy is entropy regularized.

    The current implementation does not support multi-step TD learning.
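    To make the listed differences concrete, the sketch below computes the usual single-step SAC soft TD target: the next action a' is sampled from the current policy (there is no target policy), the minimum over two Q-estimates is kept as in TD3, and the entropy regularization enters as -alpha * log pi(a'|s'). It is a plain-JAX illustration with assumed argument names and a fixed temperature alpha, not this PR's code.

```python
import jax.numpy as jnp


def sac_td_target(r, done, gamma, q1_next, q2_next, logp_next, alpha):
    """Soft TD target y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s')).

    q1_next, q2_next: target Q-values at (s', a'), with a' ~ pi(.|s') from the current policy
    logp_next:        log pi(a'|s') for that same sampled a'
    """
    soft_v_next = jnp.minimum(q1_next, q2_next) - alpha * logp_next
    return r + gamma * (1.0 - done) * soft_v_next
```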

    opened by frederikschubert 3
  • AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    Hi, I'm unsure whether this is due to coax or JAX, but I get this error when running the Pendulum PPO example; DQN runs fine, however.

    A similar error I found online was addressed by changing the jaxlib version, so I switched to the jaxlib version given in the coax Getting Started guide, but it seemed to have no effect. Versions: jax 0.2.13, jaxlib 0.1.65+cuda111, coax 0.1.6.

    question 
    opened by mmcaulif 3
  • Recurrent Experience Replay

    Is your feature request related to a problem? Please describe.

    It seems that the implemented replay buffers only operate over transitions, with no ability to operate over entire sequences. This prevents the use of recurrent policies for tackling POMDPs.

    Describe the solution you'd like

    A SequenceReplayBuffer that returns contiguous episodes instead of shuffled transitions.
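    A minimal sketch of the requested behaviour, assuming whole episodes are stored and fixed-length contiguous windows are sampled from them. SequenceReplayBuffer and its methods here are hypothetical, not an existing coax API, and details such as padding short episodes or storing recurrent hidden states are left out.

```python
import random
from collections import deque


class SequenceReplayBuffer:
    """Hypothetical sketch: stores whole episodes and samples contiguous sub-sequences."""

    def __init__(self, capacity, seq_len, seed=None):
        self.episodes = deque(maxlen=capacity)  # oldest episodes are evicted first
        self.seq_len = seq_len
        self._current = []                      # episode currently being collected
        self._rng = random.Random(seed)

    def add(self, transition, done):
        self._current.append(transition)
        if done:
            self.episodes.append(self._current)
            self._current = []

    def sample(self, batch_size):
        eligible = [ep for ep in self.episodes if len(ep) >= self.seq_len]
        if not eligible:
            raise ValueError("no stored episode is long enough")
        batch = []
        for _ in range(batch_size):
            ep = self._rng.choice(eligible)
            start = self._rng.randrange(len(ep) - self.seq_len + 1)
            batch.append(ep[start:start + self.seq_len])  # contiguous, in time order
        return batch
```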

    Describe alternatives you've considered

    Additional context

    enhancement 
    opened by smorad 3
  • MiniMax Algorithm?

    How would you implement a minimax q-learner with coax?

    Hi there! I love the package and how accessible it is to relative newbies. The tutorials are pretty great and the accompanying videos are very helpful!

    I was wondering what the best way to implement a minimax algorithm would be. Would you recommend using two policies, pi1 and pi2, or is there something better suited for this?

    I'd like to re-implement something like this old blogpost of mine in coax to get a better feel of the library.

    Any help would be greatly appreciated :)
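    One simple option, assuming a turn-based two-player zero-sum game and a single Q-function that scores actions from the maximizing player's point of view: bootstrap with the max over a' when the maximizing player moves in s', and with the min over a' when the opponent moves. The helper below is a hypothetical plain-JAX sketch of that target, not a coax API; simultaneous-move games would instead need Littman's minimax-Q, which solves a small matrix game in each state.

```python
import jax.numpy as jnp


def minimax_q_target(r, done, gamma, q_next, next_player_is_max):
    """TD target for a turn-based zero-sum game.

    r, done:            (batch,) rewards (maximizing player's view) and episode-end flags
    q_next:             (batch, num_actions) Q(s', .) from the maximizing player's view
    next_player_is_max: (batch,) bool, True if the maximizing player moves in s'
    """
    v_next = jnp.where(next_player_is_max, q_next.max(axis=-1), q_next.min(axis=-1))
    return r + gamma * (1.0 - done) * v_next
```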

    question 
    opened by flaport 1
  • Convert Numpy Docstrings to Google Style

    This issue tracks the progress of converting the numpy style docstrings to the more concise Google style.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    This depends on the type annotations (https://github.com/coax-dev/coax/issues/13), which make automatic conversion easier.
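    For comparison, here is the same hypothetical function documented first in numpy style and then in Google style; both are parsed by Sphinx's napoleon extension. The function itself is only an illustration, not part of coax.

```python
def discounted_return_numpy_style(rewards, gamma):
    """Compute the discounted return of a reward sequence.

    Parameters
    ----------
    rewards : list of float
        Rewards r_0, r_1, ..., r_{T-1}.
    gamma : float
        Discount factor.

    Returns
    -------
    G : float
        The discounted return sum_t gamma^t r_t.
    """
    return sum(gamma ** t * r for t, r in enumerate(rewards))


def discounted_return_google_style(rewards, gamma):
    """Compute the discounted return of a reward sequence.

    Args:
        rewards: Rewards r_0, r_1, ..., r_{T-1}.
        gamma: Discount factor.

    Returns:
        The discounted return sum_t gamma^t r_t.
    """
    return sum(gamma ** t * r for t, r in enumerate(rewards))
```

    Once parameters carry type annotations, the Google-style entries no longer need to repeat the types, which is presumably why the conversion is easier after the annotation work.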

    enhancement 
    opened by frederikschubert 0
  • Add Type Annotations

    This issue tracks the progress of adding type annotations to coax.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    The types are added by utilising pyannotate and adding the following snippet to the coax._base.TestCase class:

    import sys
    from pyannotate_runtime import collect_types
    ...
        @classmethod
        def setUpClass(cls) -> None:
            # start recording argument and return types while the tests run
            collect_types.init_types_collection()
            collect_types.start()
    
        @classmethod
        def tearDownClass(cls) -> None:
            collect_types.stop()
            # map private/inferred type names onto public ones
            type_replacements = {
                "jaxlib.xla_extension.DeviceArray": "jax.numpy.ndarray",
                "haiku._src.data_structures.FlatMapping": "typing.Mapping",
                "coax._core.policy_test": "gym.Env"
            }
            types_str = collect_types.dumps_stats()
            for inferred_type, replacement in type_replacements.items():
                types_str = types_str.replace(inferred_type, replacement)
            # e.g. policy_test.py -> policy_test_types.json, next to the test module
            with open(sys.modules[cls.__module__].__file__.replace(".py", "_types.json"), "w") as f:
                f.write(types_str)
    ...
    

    The collected types are then applied automatically:

    shopt -s globstar  # bash: make ** match nested directories
    for t in coax/**/*_test_types.json
    do
        pyannotate --type-info "$t" -3 coax/* -w
    done
    
    enhancement 
    opened by frederikschubert 0
  • PPOClip grad update seems to cause inf update

    PPOClip grad update seems to cause inf update

    Describe the bug

    Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. Don't spend much time investigating this yet; I'm just opening this in case something jumps out at you as the problem. I plan to keep debugging it.

    During the first PPOClip update with the custom gym environment, the model weights are changed to +/-inf even though the gradients are finite.

    Expected behavior

    ...
    adv = np.random.rand(32)
    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
    print("grads", grads)
    print(ppo_clip._pi.params)
    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
    print(ppo_clip._pi.params)
    

    Results in:

    grads FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
                  'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                    [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                    [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                    ...,
                                    [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                    [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                    [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
                }),
    
    FlatMapping({
      'linear': FlatMapping({
                  'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                    [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                    [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                    ...,
                                    [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                    [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                    [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]],            dtype=float16),
                  'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
                }),
    })
    
    FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
                  'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                    [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                    [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                    ...,
                                    [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                    [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                    [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]],            dtype=float16),
                }),
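    Every array in the printout above is float16. One plausible cause, not confirmed in this report: float16 cannot represent Adam's default eps of 1e-8 (the smallest positive float16 is about 6e-8), so eps rounds to 0, and for small gradient entries the second-moment term (1 - b2) * g**2 underflows to 0 as well. The first Adam step then divides a nonzero first moment by zero. Consistent with this, the two weights that became -inf sit at the same positions as the two smallest gradient entries shown, 0.005024 and 0.001724. The sketch below mimics one bias-corrected Adam step in NumPy; it is an illustration, not optax's code.

```python
import numpy as np


def adam_first_step(grad, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8, dtype=np.float16):
    """Parameter update from a single bias-corrected Adam step, computed in `dtype`.

    Moments start at zero, as in a freshly initialised Adam optimizer.
    """
    g = np.asarray(grad, dtype=dtype)
    lr, b1, b2, eps, one = (dtype(x) for x in (lr, b1, b2, eps, 1.0))
    m = (one - b1) * g            # first moment after one step
    v = (one - b2) * (g * g)      # second moment: underflows to 0 in float16 for small |g|
    m_hat = m / (one - b1)        # bias correction for step 1
    v_hat = v / (one - b2)
    with np.errstate(divide="ignore"):
        step = m_hat / (np.sqrt(v_hat) + eps)  # eps is itself 0 in float16
    return -lr * step
```

    With the three gradient values 0.005024, 0.001724 and 0.02452 from the printout, the float16 version returns -inf for the first two and a finite value for the third, while dtype=np.float32 keeps all three finite. If this is the cause, keeping the parameters and optimizer state in float32 would likely avoid it.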
    

    Here is the full repro script, adapted from the Pong PPO example. It won't run as-is because it depends on the custom environment. It is a dummy example; these are not the policy and value networks that would actually be used:

    import os
    from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
    from luxai2021.game.game import Game
    from luxai2021.game.actions import *
    from luxai2021.game.constants import LuxMatchConfigs_Default
    
    from luxai2021.env.agent import Agent, AgentWithTeamModel
    import numpy as np
    
    from agent import TeamAgent
    
    # set some env vars
    os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use CPU
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet
    
    import gym
    import jax
    import coax
    import haiku as hk
    import jax.numpy as jnp
    from optax import adam
    
    
    # the name of this script
    name = 'ppo'
    
    configs = LuxMatchConfigs_Default
    
    player = TeamAgent(mode="train")
    opponent = Agent()
    
    env = LuxEnvironment(configs=configs,
                                    learning_agent=player,
                                    opponent_agent=opponent)
    env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")
    
    def func_pi(S, is_training):
        n_actions = 4
        out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
        return out
    
    def func_v(S, is_training):
        h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
        return h
    
    '''
    def func_pi(S, is_training):
        #print(env.action_space.shape)
        n_filters = 5
        n_actions = 4
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
        
        print('h', type(h), h.shape)
        h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
        h_head_actions = hk.Linear(n_actions)(h_head)
        print('h_head_actions', type(h_head_actions), h_head_actions.shape)
        #print(h_head_actions)
    
        out = {'logits': h_head_actions}
        
        return out
    
    def func_v(S, is_training):
        n_filters = 5
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))
    
        h = hk.Flatten()(h)
        h = jax.nn.relu(hk.Linear(64)(h))
        h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
        
        return h
    '''
    
    
    # function approximators
    pi = coax.Policy(func_pi, env)
    v = coax.V(func_v, env)
    
    # target networks
    pi_behavior = pi.copy()
    v_targ = v.copy()
    
    # policy regularizer (avoid premature exploitation)
    entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)
    
    # updaters
    simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
    ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))
    
    # reward tracer and replay buffer
    tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
    buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
    
    # run episodes
    max_episode_steps = 400
    while env.T < 3000000:
        s = env.reset()
    
        for t in range(max_episode_steps):
            print(t)
            a, logp = pi_behavior(s, return_logp=True)
            s_next, r, done, info = env.step(a)
    
            # trace rewards and add transition to replay buffer
            tracer.add(s, a, r, done, logp)
            while tracer:
                buffer.add(tracer.pop())
    
            # learn
            if len(buffer) >= buffer.capacity:
                num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
                for i in range(num_batches):
                    transition_batch = buffer.sample(32)
                    grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                    metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
    
                    
                    adv = np.random.rand(32)
                    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                    print("grads", grads)
                    print(ppo_clip._pi.params)
                    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                    print(ppo_clip._pi.params)
                    exit()
                    env.record_metrics(metrics_pi)
                    env.record_metrics(metrics_v)
                    
    
                buffer.clear()
    
                # sync target networks
                pi_behavior.soft_update(pi, tau=0.1)
                v_targ.soft_update(v, tau=0.1)
    
            if done:
                break
    
            s = s_next
    
        # generate an animated GIF to see what's going on
        if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
            T = env.T - env.T % 10000  # round to 10000s
            coax.utils.generate_gif(
                env=env, policy=pi, resize_to=(320, 420),
                filepath=f"./data/gifs/{name}/T{T:08d}.gif")
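    To see which parameters turn non-finite after each update without reading the printed arrays, a small helper like this could be called on ppo_clip._pi.params. It is a hypothetical utility, not part of coax, and assumes the parameters form a nested mapping of arrays, as the FlatMapping printout suggests.

```python
from collections.abc import Mapping

import numpy as np


def find_nonfinite(params, prefix=""):
    """Return the '/'-joined paths of all leaves that contain inf or nan."""
    bad = []
    for key, value in params.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            bad.extend(find_nonfinite(value, path))  # recurse into nested modules
        elif not np.all(np.isfinite(np.asarray(value))):
            bad.append(path)
    return bad
```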
    
    
    opened by glmcdona 3
Releases (v0.1.12)
Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

CQL-JAX This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on

Karush Suri 8 Nov 7, 2022
Deep Q Learning with OpenAI Gym and Pokemon Showdown

pokemon-deep-learning An openAI gym project for pokemon involving deep q learning. Made by myself, Sam Little, and Layton Webber. This code captures g

null 2 Dec 22, 2021
Customizable RecSys Simulator for OpenAI Gym

gym-recsys: Customizable RecSys Simulator for OpenAI Gym Installation | How to use | Examples | Citation This package describes an OpenAI Gym interfac

Xingdong Zuo 14 Dec 8, 2022
Manipulation OpenAI Gym environments to simulate robots at the STARS lab

Manipulator Learning This repository contains a set of manipulation environments that are compatible with OpenAI Gym and simulated in pybullet. In par

STARS Laboratory 5 Dec 8, 2022
An OpenAI Gym environment for Super Mario Bros

gym-super-mario-bros An OpenAI Gym environment for Super Mario Bros. & Super Mario Bros. 2 (Lost Levels) on The Nintendo Entertainment System (NES) us

Andrew Stelmach 1 Jan 5, 2022
A plug-and-play library for neural networks written in Python

A plug-and-play library for neural networks written in Python!

Dimos Michailidis 2 Jul 16, 2022
Plug and play transformer you can find network structure and official complete code by clicking List

Plug-and-play Module Plug and play transformer you can find network structure and official complete code by clicking List The following is to quickly

null 8 Mar 27, 2022
Reinforcement Learning with Q-Learning Algorithm on gym's frozen lake environment implemented in python

Reinforcement Learning with Q Learning Algorithm Q learning algorithm is trained on the gym's frozen lake environment. Libraries Used gym Numpy tqdm P

null 1 Nov 10, 2021
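Attention modules and other plug-and-play modules used in computer vision: PyTorch Implementation Collection of Attention Module and Plug&Play Module

PyTorch implementations of various attention mechanisms used in network design for computer vision, plus a collection of plug-and-play modules. Because of limited ability and time, many modules may not be included yet; suggestions and improvements are welcome as issues or PRs.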

PJDong 599 Dec 23, 2022
Gradient Step Denoiser for convergent Plug-and-Play

Source code for the paper "Gradient Step Denoiser for convergent Plug-and-Play"

Samuel Hurault 11 Sep 17, 2022
🐥A PyTorch implementation of OpenAI's finetuned transformer language model with a script to import the weights pre-trained by OpenAI

PyTorch implementation of OpenAI's Finetuned Transformer Language Model This is a PyTorch implementation of the TensorFlow code provided with OpenAI's

Hugging Face 1.4k Jan 5, 2023
CLOOB training (JAX) and inference (JAX and PyTorch)

cloob-training Pretrained models There are two pretrained CLOOB models in this repo at the moment, a 16 epoch and a 32 epoch ViT-B/16 checkpoint train

Katherine Crowson 64 Nov 27, 2022
Trading Gym is an open source project for the development of reinforcement learning algorithms in the context of trading.

Trading Gym Trading Gym is an open-source project for the development of reinforcement learning algorithms in the context of trading. It is currently

Dimitry Foures 535 Nov 15, 2022
gym-anm is a framework for designing reinforcement learning (RL) environments that model Active Network Management (ANM) tasks in electricity distribution networks.

gym-anm is a framework for designing reinforcement learning (RL) environments that model Active Network Management (ANM) tasks in electricity distribution networks. It is built on top of the OpenAI Gym toolkit.

Robin Henry 99 Dec 12, 2022
Multi-objective gym environments for reinforcement learning.

MO-Gym: Multi-Objective Reinforcement Learning Environments Gym environments for multi-objective reinforcement learning (MORL). The environments follo

Lucas Alegre 74 Jan 3, 2023
GAN JAX - A toy project to generate images from GANs with JAX

GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developed by Google and

Valentin Goldité 14 Nov 29, 2022
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 3, 2022
BasicRL: easy and fundamental codes for deep reinforcement learning. It is an improvement on rainbow-is-all-you-need and OpenAI Spinning Up.

BasicRL: easy and fundamental codes for deep reinforcement learning BasicRL is an improvement on rainbow-is-all-you-need and OpenAI Spinning Up. It is

RayYoh 12 Apr 28, 2022
A Jupyter notebook to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

A Jupyter notebook to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

Eugenio Herrera 175 Dec 29, 2022