From 7ac2f5d318f8f531e4333fa2a8043bbec27e0a5d Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Mon, 30 Sep 2024 00:50:26 +0200
Subject: [PATCH] Fix enforce_cost_limit bugs in run_vectorized_exploration
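
Two fixes for args.enforce_cost_limit:

The mask update combined boolean arrays with Python's `and`, which
does not work elementwise on numpy arrays: with more than one
vectorized env it raises "ValueError: The truth value of an array
with more than one element is ambiguous". Use the elementwise `&`
operator instead.

When the episode branch fires only because every env has been masked
out (is_mask_zero) and none is done, info carries no
'final_observation', so skip the corresponding buffer.add. Those
masked-out envs also never auto-reset (auto-reset only fires on
done), so reset the vector env explicitly before starting the next
episodes.

A minimal reproduction of the operator issue (the values are
illustrative; any multi-element numpy bool array behaves the same):

    >>> import numpy as np
    >>> mask = np.ones(4, dtype='bool')
    >>> episode_cost = np.array([0.0, 5.0, 1.0, 9.0])
    >>> mask and (episode_cost < 3.0)    # old code
    Traceback (most recent call last):
      ...
    ValueError: The truth value of an array with more than one
    element is ambiguous. Use a.any() or a.all()
    >>> mask & (episode_cost < 3.0)      # fixed: elementwise AND
    array([ True, False,  True, False])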

---
 main.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/main.py b/main.py
index 83eee60..d01e3fc 100644
--- a/main.py
+++ b/main.py
@@ -120,7 +120,7 @@ def setup(args):
     return env, agent, buffer, writer
 
 ##################
-# USEFUL FUNCTIONS
+# EXPLORATION
 ##################
 
 @torch.no_grad
@@ -155,7 +155,9 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
         episode_reward[mask] += reward[mask]
         episode_cost[mask] += cost[mask]
 
-        if done.any() or ~mask.any():
+        is_mask_zero = ~mask.any()  # True once every env has been masked out
+
+        if done.any() or is_mask_zero:
             avg_steps += episode_steps.sum()
             avg_reward += episode_reward.sum()
             avg_cost += episode_cost.sum()
@@ -165,13 +167,14 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             total_steps += episode_steps.sum()
             total_failures += (episode_cost >= args.cost_limit).sum()
 
-            if train:
+            if train and not is_mask_zero:  # info may lack 'final_observation' when no env is done
                 buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
 
             mask = np.ones(args.num_vectorized_envs, dtype='bool')
             episode_steps = np.zeros(args.num_vectorized_envs)
             episode_reward = np.zeros(args.num_vectorized_envs)
             episode_cost = np.zeros(args.num_vectorized_envs)
+            state, _ = env.reset()  # auto-reset only fires on done; the is_mask_zero case needs an explicit reset
 
             finished_count = episode_count
             open_episodes = num_episodes - episode_count
@@ -179,13 +182,12 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             mask[idx:] = False
             episode_count += idx
         
-        elif train:
-            buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
-            if args.enforce_cost_limit:   # we dont care about the cost limit while testing
-                mask = mask and episode_cost < args.cost_limit
         else:
-            pass    # eval
-        state = next_state
+            if train:
+                buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
+                if args.enforce_cost_limit:   # we don't care about the cost limit while testing
+                    mask = mask & (episode_cost < args.cost_limit)  # elementwise &, not Python 'and'
+            state = next_state
     
     avg_steps /= num_episodes
     avg_reward /= num_episodes
-- 
GitLab