diff --git a/main.py b/main.py
index 83eee6094bafa62276d76d7cf4945bd4c05fb4ae..d01e3fc16c6ac2bc06e836f4058b6ab070b10660 100644
--- a/main.py
+++ b/main.py
@@ -120,7 +120,7 @@ def setup(args):
     return env, agent, buffer, writer
 
 ##################
-# USEFUL FUNCTIONS
+# EXPLORATION
 ##################
 
 @torch.no_grad
@@ -155,7 +155,9 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
         episode_reward[mask] += reward[mask]
         episode_cost[mask] += cost[mask]
 
-        if done.any() or ~mask.any():
+        is_mask_zero = ~mask.any()  # True once every env has been masked out
+
+        if done.any() or is_mask_zero:
             avg_steps += episode_steps.sum()
             avg_reward += episode_reward.sum()
             avg_cost += episode_cost.sum()
@@ -165,13 +167,14 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             total_steps += episode_steps.sum()
             total_failures += (episode_cost >= args.cost_limit).sum()
 
-            if train:
+            if train and not is_mask_zero:
                 buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
 
             mask = np.ones(args.num_vectorized_envs, dtype='bool')
             episode_steps = np.zeros(args.num_vectorized_envs)
             episode_reward = np.zeros(args.num_vectorized_envs)
             episode_cost = np.zeros(args.num_vectorized_envs)
+            state, _ = env.reset()  # the vector env auto-resets on done, but the is_mask_zero case requires a manual reset
 
             finished_count = episode_count
             open_episodes = num_episodes - episode_count
@@ -179,13 +182,12 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             mask[idx:] = False
             episode_count += idx
         
-        elif train:
-            buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
-            if args.enforce_cost_limit:   # we dont care about the cost limit while testing
-                mask = mask and episode_cost < args.cost_limit
         else:
-            pass    # eval
-        state = next_state
+            if train:
+                buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
+                if args.enforce_cost_limit:   # we don't care about the cost limit while testing
+                    mask = mask & (episode_cost < args.cost_limit)
+            state = next_state
     
     avg_steps /= num_episodes
     avg_reward /= num_episodes
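
For reviewers, a minimal standalone sketch (not part of the patch) of why the old `mask and episode_cost < args.cost_limit` had to become `mask & (episode_cost < args.cost_limit)`: the masks here are NumPy boolean arrays, and Python's `and` calls `bool()` on the whole array, which raises a ValueError for vectorized envs, whereas the elementwise `&` with a parenthesized comparison yields the intended per-env mask. The `cost_limit` value and the array contents below are made up for illustration.

```python
import numpy as np

cost_limit = 25.0                                  # stands in for args.cost_limit
mask = np.array([True, True, False, True])         # envs still being collected
episode_cost = np.array([3.0, 40.0, 1.0, 10.0])    # accumulated cost per env

# mask and (episode_cost < cost_limit)   # ValueError: truth value of an array is ambiguous
mask = mask & (episode_cost < cost_limit)          # elementwise: drops the over-limit env
print(mask)                                        # [ True False False  True]
```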