It looks like there is a problem with the parameter pmap when using the Adam optimizer.
This (note that I have dropped --sr):
export CUDA_VISIBLE_DEVICES=0,1,2,3
python main.py --n 13 --dim 2 --rs 10.0 --Theta 0.15 --Emax 25 --batch 4096 --num_devices 4 --acc_steps 2
gives
...
iter: 0001 F: -1.9656119959069707 F_std: 0.002970439785129169 E: -1.9296861940063403 E_std: 0.0029793946046783874 K: 0.2892473978205551 K_std: 0.00013518503689654573 V: -2.2189335918268953 V_std: 0.002970435066943954 S: 5.98763365010506 S_std: 0.022527996149788027 accept_rate: 0.6797216796875
Traceback (most recent call last):
File "/home/wanglei/CoulombGas/main.py", line 322, in <module>
keys, state_indices, x, accept_rate = sample_stateindices_and_x(keys,
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2058, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/api.py", line 1934, in f_pmapped
out = pxla.xla_pmap(
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 1727, in bind
return call_bind(self, fun, *args, **params)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 1652, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 1730, in process
return trace.process_map(self, fun, tracers, params)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/core.py", line 633, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 766, in xla_pmap_impl
compiled_fun, fingerprint = parallel_callable(
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
ans = call(fun, *args)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 794, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 964, in lower_parallel_callable
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 871, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1566, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1543, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/wanglei/CoulombGas/VMC.py", line 22, in sample_stateindices_and_x
state_indices = sampler(params_van, key_state, batch)
File "/home/wanglei/CoulombGas/sampler.py", line 37, in sampler
logits = jax.vmap(_logits, (None, 0), 0)(params, state_indices)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/_src/api.py", line 1520, in batched_fun
out_flat = batching.batch(
File "/home/wanglei/.local/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/wanglei/CoulombGas/sampler.py", line 27, in _logits
logits = network.apply(params, None, sp_indices[state_idx])
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 400, in apply_fn
out = f(*args, **kwargs)
File "/home/wanglei/CoulombGas/main.py", line 94, in forward_fn
return model(state_idx)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/wanglei/CoulombGas/autoregressive.py", line 73, in __call__
x = hk.Linear(self.model_size,
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/basic.py", line 174, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/base.py", line 331, in get_parameter
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'transformer/embedding_mlp/w' with retrieved shape (4, 4, 2, 16) does not match shape=[2, 16] dtype=dtype('int64')
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/wanglei/CoulombGas/main.py", line 322, in <module>
keys, state_indices, x, accept_rate = sample_stateindices_and_x(keys,
File "/home/wanglei/CoulombGas/VMC.py", line 22, in sample_stateindices_and_x
state_indices = sampler(params_van, key_state, batch)
File "/home/wanglei/CoulombGas/sampler.py", line 37, in sampler
logits = jax.vmap(_logits, (None, 0), 0)(params, state_indices)
File "/home/wanglei/CoulombGas/sampler.py", line 27, in _logits
logits = network.apply(params, None, sp_indices[state_idx])
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/transform.py", line 400, in apply_fn
out = f(*args, **kwargs)
File "/home/wanglei/CoulombGas/main.py", line 94, in forward_fn
return model(state_idx)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/wanglei/CoulombGas/autoregressive.py", line 73, in __call__
x = hk.Linear(self.model_size,
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/basic.py", line 174, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
File "/home/wanglei/.local/lib/python3.9/site-packages/haiku/_src/base.py", line 331, in get_parameter
raise ValueError(
ValueError: 'transformer/embedding_mlp/w' with retrieved shape (4, 4, 2, 16) does not match shape=[2, 16] dtype=dtype('int64')