
How to solve it? Indexer must have integer or boolean type #13

Open
strivebfq opened this issue Dec 26, 2023 · 0 comments


@strivebfq

(gig) fuqing@vcis13:~/test$ python 1
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Traceback (most recent call last):
  File "/home/fuqing/test/1", line 13, in <module>
    obs, state, rewards, dones, ep_done = env.step(state, action, key_step)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/type_enforced/enforcer.py", line 148, in get_fn
    return self.call(*args, **kwargs)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/type_enforced/enforcer.py", line 181, in call
    return_value = self.fn(*args, **kwargs)
  File "/home/fuqing/test/gigastep-main/gigastep/gigastep_env.py", line 362, in step
    agent_states = v_step(agent_states, actions, self._per_agent_thrust)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/api.py", line 317, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/fuqing/test/gigastep-main/gigastep/gigastep_env.py", line 311, in _step_agents
    action = self.action_lut[action].reshape(3).astype(jnp.float32)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/array.py", line 336, in __getitem__
    return lax_numpy._rewriting_take(self, idx)
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4494, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4503, in _gather
    indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
  File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4763, in _index_to_gather
    raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=2/0)>
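A likely cause, reading the last frames of the traceback: `_step_agents` in `gigastep_env.py` uses each agent's action as an index into `self.action_lut` (`self.action_lut[action]`), so it expects one integer action id per agent, but the value reaching it is a `float32[3]` vector. That pattern suggests continuous-style actions were passed to an environment set up for discrete actions. The minimal JAX sketch below reproduces the same `TypeError` outside gigastep; `action_lut` is a hypothetical 3-row stand-in and `lookup` is a hypothetical helper that mirrors line 311, neither is gigastep's actual code.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for gigastep's action_lut: one row of 3 values per
# discrete action id (the values are illustrative, not gigastep's real table).
action_lut = jnp.array([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [-1.0, 0.0, 0.0],
])


def lookup(action):
    # Hypothetical helper mirroring the failing line: the action is used
    # directly as an index into the lookup table.
    return action_lut[action].reshape(3).astype(jnp.float32)


# Discrete actions: one integer id per agent, so the lookup works under vmap.
int_actions = jnp.array([0, 2, 1], dtype=jnp.int32)
ok_out = jax.vmap(lookup)(int_actions)  # shape (3, 3): one table row per agent
print(ok_out)

# Continuous-style actions: a float32[3] vector per agent. JAX refuses float
# indexers and raises "TypeError: Indexer must have integer or boolean type ...".
float_actions = jnp.zeros((3, 3), dtype=jnp.float32)
float_error = None
try:
    jax.vmap(lookup)(float_actions)
except TypeError as err:
    float_error = err
print(float_error)
```

If this diagnosis is right, there seem to be two ways out: pass integer action ids (for example sampled with `jax.random.randint` over the number of discrete actions) when the environment uses discrete actions, or configure the environment for continuous actions so that `env.step` is expected to take the 3-component float vectors. The exact gigastep option or scenario name that selects the action mode is not shown in this issue and is not confirmed here.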
