You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
(gig) fuqing@vcis13:~/test$ python 1
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Traceback (most recent call last):
File "/home/fuqing/test/1", line 13, in <module>
obs, state, rewards, dones, ep_done = env.step(state, action, key_step)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/type_enforced/enforcer.py", line 148, in get_fn
return self.call(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/type_enforced/enforcer.py", line 181, in call
return_value = self.fn(*args, **kwargs)
File "/home/fuqing/test/gigastep-main/gigastep/gigastep_env.py", line 362, in step
agent_states = v_step(agent_states, actions, self._per_agent_thrust)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/api.py", line 1258, in vmap_f
out_flat = batching.batch(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/api.py", line 317, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/fuqing/test/gigastep-main/gigastep/gigastep_env.py", line 311, in _step_agents
action = self.action_lut[action].reshape(3).astype(jnp.float32)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/array.py", line 336, in __getitem__
return lax_numpy._rewriting_take(self, idx)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4494, in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4503, in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4763, in _index_to_gather
raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=2/0)>
The text was updated successfully, but these errors were encountered:
(gig) fuqing@vcis13:~/test$ python 1
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Traceback (most recent call last):
File "/home/fuqing/test/1", line 13, in <module>
obs, state, rewards, dones, ep_done = env.step(state, action, key_step)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/type_enforced/enforcer.py", line 148, in get_fn
return self.call(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/type_enforced/enforcer.py", line 181, in call
return_value = self.fn(*args, **kwargs)
File "/home/fuqing/test/gigastep-main/gigastep/gigastep_env.py", line 362, in step
agent_states = v_step(agent_states, actions, self._per_agent_thrust)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/api.py", line 1258, in vmap_f
out_flat = batching.batch(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/api.py", line 317, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/fuqing/test/gigastep-main/gigastep/gigastep_env.py", line 311, in _step_agents
action = self.action_lut[action].reshape(3).astype(jnp.float32)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/array.py", line 336, in __getitem__
return lax_numpy._rewriting_take(self, idx)
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4494, in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4503, in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
File "/home/fuqing/miniconda3/envs/gig/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4763, in _index_to_gather
raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=2/0)>
The text was updated successfully, but these errors were encountered: