I've started learning about RL and have been trying to get coax up and running, but I've run into an issue that I'm not sure how to resolve. I'm doing Q-learning on a custom gym environment, and I can run the following pieces successfully:
# Q-function approximator q(s, a) wrapping func_q (defined further down).
q = coax.Q(func_q, env)
# Policy wrapping func_pi (defined further down). It is used both to pick
# actions in the training loop (a = pi(s)) and as pi_targ for Q-learning.
# NOTE(review): the training loop shown later only calls qlearning.update,
# so, as far as the visible code shows, pi's own parameters are never
# trained. Confirm this is intended.
pi = coax.Policy(func_pi, env)
# Q-learning updater for q, with pi passed as the target policy and Adam at
# learning rate 0.001.
qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
# Reward tracer configured with n=1 and discount gamma=0.9.
cache = coax.reward_tracing.NStep(n=1, gamma=0.9)
Additionally, my setup passes these simple checks:
# Sanity checks on single samples drawn from the env's spaces.
# NOTE(review): these checks only exercise single (unbatched) samples. The
# failure in the traceback happens inside qlearning.update on a
# transition_batch, so the batch handling in func_q is the first thing to check.
data = coax.Q.example_data(env) # Looks good
...  # presumably unrelated code elided from the report
s = env.observation_space.sample()
a = env.action_space.sample()
# 0.0 is consistent with func_q's final layer using w_init=jnp.zeros
# (presumably with a zero-initialised bias as well).
print(q(s,a)) # 0.0
...  # presumably unrelated code elided from the report
a = pi(s)
print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space
However, once I get to actually running the training loop:
# Training loop: up to 50 episodes, each at most env.maxGuesses steps.
for ep in range(50):
    # NOTE(review): coax.Policy is a learned (parametrised) policy, so setting
    # `epsilon` on it is likely a no-op; epsilon-greedy exploration in coax is
    # usually done with coax.EpsilonGreedy. Confirm for this setup.
    pi.epsilon = 0.1
    s = env.reset()

    for t in range(env.maxGuesses):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # update: hand the transition to the reward tracer, then train q on
        # every transition batch the tracer has ready.
        cache.add(s, a, r, done)
        while cache:
            transition_batch = cache.pop()
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

        if done:
            break

        s = s_next

    # early stopping (checked once per episode, after the step loop)
    if env.avg_G > env.reward_threshold:
        break
I get a long chain of errors; the most human-readable one says:
ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
By changing the environment's parameters I can change which numbers are mismatched, but I can't get them to match. Either way, that seems like the wrong fix, since something more fundamental seems to be wrong.
For reference, here are my functions for q and pi:
def func_pi(S, is_training):
    """Policy network for a MultiDiscrete action space: one categorical per letter.

    Returns a tuple of ``Wordle.wordLength`` dicts, each of the form
    ``{'logits': array}`` with ``len(alphabet)`` logits along the last axis.

    NOTE(review): assumes ``state_to_vec(S)`` returns an array whose last axis
    holds the features (e.g. shape (batch, features)) — TODO confirm.
    """
    logits = hk.Sequential((
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        # One logit per (letter position, alphabet symbol) pair.
        hk.Linear(Wordle.wordLength * len(alphabet), w_init=jnp.zeros),
    ))
    # Convert the internal state representation to an array.
    sVec = state_to_vec(S)
    logitVec = logits(sVec)
    # Split the feature (last) axis into Wordle.wordLength alphabet-sized
    # chunks. The default axis=0 would split along a leading batch axis
    # instead of along the logits.
    chunks = jnp.split(logitVec, Wordle.wordLength, axis=-1)
    # reshape(-1, ...) lets the batch dimension be inferred instead of
    # hard-coding a batch size of 1; a 1-D chunk still becomes (1, len(alphabet)).
    return tuple(
        {'logits': jnp.reshape(chunk, (-1, len(alphabet)))} for chunk in chunks
    )
# and for actual state:
def func_q(S, A, is_training):
    """Type-1 Q-function: map a (state, action) input to a value q(s, a).

    The final ``jnp.ravel`` flattens the (..., 1) network output.

    NOTE(review): assumes ``state_to_vec``/``action_to_vec`` return arrays with
    the features on the last axis and the batch axis first, i.e. shape
    (batch, features). If either one flattens the whole batch into a single
    vector, the first Linear layer's input width changes with the batch size.
    That would be consistent with the reported "'linear/w' ... does not match
    shape" error. TODO confirm.
    """
    value = hk.Sequential((
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel,
    ))
    sVec = state_to_vec(S)
    aVec = action_to_vec(A)
    # Concatenate along the feature (last) axis. The default axis=0 would
    # join batched inputs along the batch axis instead. For 1-D inputs the
    # two are identical.
    X = jnp.concatenate((sVec, aVec), axis=-1)
    return value(X)
Note that state_to_vec(S) and action_to_vec(A) just convert from my internal types to jnp arrays for use with Haiku.
I'm quite new to coax/JAX/Haiku, so it's entirely possible I've set something up wrong. For completeness, here's the full text of the error:
Traceback (most recent call last):
File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
metrics = qlearning.update(transition_batch)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
return self._grads_and_metrics_func(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
out_flat = xla.xla_call(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
return call_bind(self, fun, *args, **params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
ans = call(fun, *args)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
(_, aux), g = value_and_grad_f(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
ans, vjp_py, aux = _vjp(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
out_primal, out_vjp, aux = ad.vjp(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
out_flat = xla.xla_call(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
return call_bind(self, fun, *args, **params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
return call_bind(self, fun, *args, **params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
return call_bind(self, fun, *args, **params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
out = f(*args, **kwargs)
File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
return value(X)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
out = layer(out, *args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
metrics = qlearning.update(transition_batch)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
return self._grads_and_metrics_func(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
out = f(*args, **kwargs)
File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
return value(X)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
out = layer(out, *args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
raise ValueError(
ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
Please let me know if any other information would be useful or relevant (or if this isn't actually a coax issue).
Thanks for your help and the neat package.
bug good first issue question