Skip to content
Snippets Groups Projects
Unverified Commit 23646987 authored by Haonan Yu's avatar Haonan Yu Committed by GitHub
Browse files

fix suite_babyai TimeLimit wrapper (#936)

parent 7acb5d56
No related branches found
No related tags found
No related merge requests found
......@@ -68,12 +68,15 @@ def load(environment_name,
Returns:
An AlfEnvironment instance.
"""
gym_spec = gym.spec(environment_name)
gym_env = gym_spec.make()
# babyai doesn't register max_episode_steps in Gym
gym_env = gym.make(environment_name)
# but it does have an internal property ``max_steps``, and the env returns ``done=True``
# when reached, see
# https://github.com/maximecb/gym-minigrid/blob/6f5fe8588d05eb13a08f971fd3c7a82c404dc1bb/gym_minigrid/minigrid.py#L1158
if max_episode_steps is None:
if gym_spec.max_episode_steps is not None:
max_episode_steps = gym_spec.max_episode_steps
if hasattr(gym_env, 'max_steps'):
# minus 1 because ALF's ``TimeLimit`` wrapper must trigger before the env itself returns ``done=True``
max_episode_steps = gym_env.max_steps - 1
else:
max_episode_steps = 0
......
......@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import alf
from alf.environments import suite_babyai
from alf.environments.alf_wrappers import TimeLimit
class SuiteBabyAITest(alf.test.TestCase):
......@@ -80,6 +82,24 @@ class SuiteBabyAITest(alf.test.TestCase):
np.alltrue(obs['mission'] == instr1)
or np.alltrue(obs['mission'] == instr2))
def test_timelimit_discount(self):
    """Check that ``suite_babyai.load`` wraps the env in ``TimeLimit``.

    The raw Gym env registers no ``max_episode_steps``; instead it ends
    the episode itself (``done=True``) after ``max_steps`` steps. The env
    returned by ``suite_babyai.load`` is expected to be a ``TimeLimit``
    whose ``duration`` is ``max_steps - 1``.
    """
    env_name = "BabyAI-GoToObj-v0"
    gym_env = gym.make(env_name)
    gym_spec = gym.spec(env_name)
    self.assertIsNone(gym_spec.max_episode_steps)

    # First confirm that the original env reports done=True on timeout.
    self.assertTrue(hasattr(gym_env, 'max_steps'))
    gym_env.reset()
    # Initialised so that an empty loop fails the assertion below instead
    # of raising an unbound-name error.
    done = False
    for _ in range(gym_env.max_steps):
        _, _, done, _ = gym_env.step(0)
    self.assertTrue(done)  # timelimit

    # Then confirm that suite_babyai wraps the env with a TimeLimit that
    # is one step shorter than the env's own limit.
    env = suite_babyai.load(env_name)
    self.assertIsInstance(env, TimeLimit)
    self.assertEqual(env.duration, gym_env.max_steps - 1)
# Call ALF's test entry point when this file is executed as a script.
if __name__ == '__main__':
alf.test.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment