From 5608549803fa81526f4ff1603b290035a3468d3e Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Mon, 30 Sep 2024 17:00:29 +0200
Subject: [PATCH] Bugfix eval total counter

---
 main.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index 259f40f..ed05316 100644
--- a/main.py
+++ b/main.py
@@ -156,12 +156,13 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             avg_cost += episode_cost.sum()
             avg_failures += (episode_cost >= args.cost_limit).sum()
 
-            stats.total_episodes += episode_count - finished_count
-            stats.total_steps += episode_steps.sum()
-            stats.total_failures += (episode_cost >= args.cost_limit).sum()
+            if train:
+                stats.total_episodes += episode_count - finished_count
+                stats.total_steps += episode_steps.sum()
+                stats.total_failures += (episode_cost >= args.cost_limit).sum()
 
-            if train and not is_mask_zero:
-                buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
+                if not is_mask_zero:
+                    buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
 
             mask = np.ones(args.num_vectorized_envs, dtype='bool')
             episode_steps = np.zeros(args.num_vectorized_envs)
-- 
GitLab