Hi,
After installation, the "CPU part" (jackhmmer and hhblits) works well.
But when I start the GPU part, I get this error message:
TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
1st part: ./run_feature.sh -d data -o ./tmp -m model_1,model_2,model_3,model_4,model_5 -f ./query/1crn.fasta -t 2021-07-27
2nd part: ./run_alphafold.sh -d data -o ./tmp -m model_1,model_2,model_3,model_4,model_5 -f ./query/1crn.fasta -t 2021-07-27
Full error message:
File "/softwares/alphafold/run_alphafold.py", line 316, in
app.run(main)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/softwares/alphafold/run_alphafold.py", line 289, in main
predict_structure(
File "/softwares/alphafold/run_alphafold.py", line 188, in predict_structure
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
File "/softwares/alphafold/alphafold/relax/relax.py", line 58, in process
out = amber_minimize.run_pipeline(
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 482, in run_pipeline
ret.update(get_violation_metrics(prot))
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 356, in get_violation_metrics
structural_violations, struct_metrics = find_violations(prot)
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 338, in find_violations
violations = folding.find_structural_violations(
File "/softwares/alphafold/alphafold/model/folding.py", line 757, in find_structural_violations
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
File "/softwares/alphafold/alphafold/model/utils.py", line 39, in batched_gather
return take_fn(params, indices)
File "/softwares/alphafold/alphafold/model/utils.py", line 36, in
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5383, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/api.py", line 411, in cache_miss
out_flat = xla.xla_call(
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1618, in bind
return call_bind(self, fun, *args, **params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1609, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1621, in process
return trace.process_call(self, fun, tracers, params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 615, in process_call
return primitive.impl(f, *tracers, **params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 622, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 694, in _xla_callable
return lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs).compile().unsafe_call
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 702, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1522, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1500, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5390, in _take
_check_arraylike("take", a)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 559, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/softwares/alphafold/run_alphafold.py", line 316, in
app.run(main)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/softwares/alphafold/run_alphafold.py", line 289, in main
predict_structure(
File "/softwares/alphafold/run_alphafold.py", line 188, in predict_structure
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
File "/softwares/alphafold/alphafold/relax/relax.py", line 58, in process
out = amber_minimize.run_pipeline(
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 482, in run_pipeline
ret.update(get_violation_metrics(prot))
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 356, in get_violation_metrics
structural_violations, struct_metrics = find_violations(prot)
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 338, in find_violations
violations = folding.find_structural_violations(
File "/softwares/alphafold/alphafold/model/folding.py", line 757, in find_structural_violations
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
File "/softwares/alphafold/alphafold/model/utils.py", line 39, in batched_gather
return take_fn(params, indices)
File "/softwares/alphafold/alphafold/model/utils.py", line 36, in
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5383, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5390, in _take
_check_arraylike("take", a)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 559, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.