Skip to content
Snippets Groups Projects
Unverified Commit bbc666d2 authored by Haichao Zhang's avatar Haichao Zhang Committed by GitHub
Browse files

Add flatten to create_environment; fix time_step stack issue (#1425)

parent 47117679
No related branches found
No related tags found
No related merge requests found
......@@ -182,7 +182,7 @@ class ParallelAlfEnvironment(alf_environment.AlfEnvironment):
self._time_step_with_env_info_spec, *time_steps)
else:
stacked = nest.fast_map_structure(
lambda *arrays: torch.stack(arrays), *time_steps)
lambda *arrays: numpy.stack(arrays), *time_steps)
stacked = nest.map_structure(
lambda x: torch.as_tensor(x, device='cpu'), stacked)
if alf.get_default_device() == "cuda":
......
......@@ -61,6 +61,7 @@ def create_environment(env_name='CartPole-v0',
env_load_fn=suite_gym.load,
num_parallel_environments=30,
nonparallel=False,
flatten=True,
seed=None,
batched_wrappers=()):
"""Create a batched environment.
......@@ -77,6 +78,8 @@ def create_environment(env_name='CartPole-v0',
num_parallel_environments (int): num of parallel environments
nonparallel (bool): force to create a single env in the current
process. Used for correctly exposing game gin confs to tensorboard.
flatten (bool): whether to use flatten action and time_steps during
communication to reduce overhead.
seed (None|int): random number seed for environment. A random seed is
used if None.
batched_wrappers (Iterable): a list of wrappers which can wrap batched
......@@ -119,7 +122,7 @@ def create_environment(env_name='CartPole-v0',
alf_env = parallel_environment.ParallelAlfEnvironment(
[functools.partial(env_load_fn, env_name)] *
num_parallel_environments,
flatten=True)
flatten=flatten)
if seed is None:
alf_env.seed([
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment