From 17ca6a32c91a45d741fb99422161e814ea9e9a88 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Mon, 30 Sep 2024 13:56:55 +0200
Subject: [PATCH] Added more detailed statistics

---
 main.py      | 55 ++++++++++++++++++++++++------------------------
 src/stats.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+), 28 deletions(-)
 create mode 100644 src/stats.py

diff --git a/main.py b/main.py
index d01e3fc..943b15c 100644
--- a/main.py
+++ b/main.py
@@ -10,15 +10,7 @@ import datetime
 from torch.utils.tensorboard import SummaryWriter
 from src.buffer import ReplayBuffer
 from src.policy import CSCAgent
-
-##################
-# GLOBALS
-##################
-
-total_episodes = 0
-total_failures = 0
-total_steps = 0
-total_updates = 0
+from src.stats import Statistics
 
 ##################
 # ARGPARSER
@@ -117,15 +109,17 @@ def setup(args):
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
     agent = CSCAgent(env, args, writer)
 
-    return env, agent, buffer, writer
+    stats = Statistics()
+
+    return env, agent, buffer, writer, stats
 
 ##################
 # EXPLORATION
 ##################
 
 @torch.no_grad
-def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agent, buffer, writer, train, shielded):
-    global total_episodes, total_steps, total_failures
+def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agent, buffer, stats, train, shielded):
+
     avg_steps = 0
     avg_reward = 0
     avg_cost = 0
@@ -163,9 +157,9 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             avg_cost += episode_cost.sum()
             avg_failures += (episode_cost >= args.cost_limit).sum()
 
-            total_episodes += episode_count - finished_count
-            total_steps += episode_steps.sum()
-            total_failures += (episode_cost >= args.cost_limit).sum()
+            stats.total_episodes += episode_count - finished_count
+            stats.total_steps += episode_steps.sum()
+            stats.total_failures += (episode_cost >= args.cost_limit).sum()
 
             if train and not is_mask_zero:
                 buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
@@ -200,33 +194,38 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
 # MAIN LOOP
 ##################
 
-def main(args, env, agent, buffer, writer):
-    global total_episodes, total_steps, total_failures, total_updates
+def main(args, env, agent, buffer, writer, stats:Statistics):
     finished = False
     while not finished:
         for _ in range(args.train_until_test):
-            if total_steps >= args.total_steps:
+            if stats.total_steps >= args.total_steps:
                 finished = True
                 break
-            avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, writer, train=True, shielded=True)
+            stats.begin(name="train")
+            avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
             print(f"[TRAIN] avg_steps: {round(avg_steps, 4)}, avg_reward: {round(avg_reward, 4)}, avg_cost: {round(avg_cost, 4)}, avg_failures: {round(avg_failures, 4)}")
+            stats.end(name="train")
+            stats.begin(name="update")
             for i in range(args.update_iterations):
-                total_updates += 1
-                agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=total_episodes)
+                stats.total_updates += 1
+                agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=stats.total_episodes + i)
+            stats.end(name="update")
             if args.clear_buffer:
                 buffer.clear()
-        print(f"[TOTAL] episodes: {total_episodes}, steps: {total_steps}, failures: {total_failures}, updates: {total_updates}")
 
+        stats.begin(name="test")
         for shielded, postfix in zip([True, False], ["shielded", "unshielded"]):
-            avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, writer, train=False, shielded=shielded)
-            writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, total_episodes)
-            writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, total_episodes)
-            writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, total_episodes)
+            avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
+            writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, stats.total_episodes)
+            writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, stats.total_episodes)
+            writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, stats.total_episodes)
             print(f"[TEST_{postfix.upper()}] avg_steps: {round(avg_steps, 4)}, avg_reward: {round(avg_reward, 4)}, avg_cost: {round(avg_cost, 4)}, avg_failures: {round(avg_failures, 4)}")
+        stats.end(name="test")
+        stats.print()
 
     writer.flush()
 
 if __name__ == '__main__':
     args = cmd_args()
-    env, agent, buffer, writer = setup(args)
-    main(args, env, agent, buffer, writer)
\ No newline at end of file
+    env, agent, buffer, writer, stats = setup(args)
+    main(args, env, agent, buffer, writer, stats)
\ No newline at end of file
diff --git a/src/stats.py b/src/stats.py
new file mode 100644
index 0000000..cd59501
--- /dev/null
+++ b/src/stats.py
@@ -0,0 +1,59 @@
+import time
+
+class Statistics():    
+    def __init__(self):
+        self.total_episodes = 0
+        self.total_failures = 0
+        self.total_steps = 0
+        self.total_updates = 0
+        
+        self.time_avg_train = 0
+        self.train_count = 0
+        self._time_train = None
+
+        self.time_avg_update = 0
+        self.update_count = 0
+        self._time_update = None
+
+        self.time_avg_test = 0
+        self.test_count = 0
+        self._time_test = None
+
+        self.time_start = time.time()
+
+
+    def total_time(self, name=None):
+        if name is None: 
+            t = time.time() - self.time_start
+        else: 
+            t = self.__dict__[f"time_avg_{name}"] * self.__dict__[f"{name}_count"]
+        return round(t,2)
+
+
+    def begin(self, name):
+        self.__dict__[f"_time_{name}"] = time.time()
+
+
+    def end(self, name):
+        t = time.time()
+        d = t - self.__dict__[f"_time_{name}"]
+        self.__dict__[f"{name}_count"] += 1
+        c = self.__dict__[f"{name}_count"]
+        a = 1/c
+        b = 1 - a
+        self.__dict__[f"time_avg_{name}"] = a*d + b*self.__dict__[f"time_avg_{name}"]
+
+
+    def print(self):
+        time = self.total_time()
+        time_train = self.total_time("train")
+        time_update = self.total_time("update")
+        time_test = self.total_time("test")
+        steps = self.total_steps
+        episodes = self.total_episodes
+        failures = self.total_failures
+        updates = self.total_updates
+        avg_failures = round((self.total_failures/self.total_episodes)*100,2) if self.total_episodes > 0 else 0.00
+        
+        print(f"[TOTAL] total_time: {time}s, total_train: {time_train}s, total_update: {time_update}s, total_test: {time_test}s,\n \
+       steps: {steps}, episodes: {episodes}, failures: {failures}, avg_failures: {avg_failures}%, updates: {updates}")
\ No newline at end of file
-- 
GitLab