From 7dffa991951feae4e1f633e584a1b27cbf3b70a5 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Thu, 20 Feb 2025 14:37:24 +0100
Subject: [PATCH] Added stats

---
 src/stats.py | 163 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 118 insertions(+), 45 deletions(-)

diff --git a/src/stats.py b/src/stats.py
index bd886d9..d584f00 100644
--- a/src/stats.py
+++ b/src/stats.py
@@ -1,52 +1,125 @@
-import time
+import numpy as np
 
-class Statistics():    
-    def __init__(self):
-        self.total_episodes = 0
-        self.total_failures = 0
-        self.total_steps = 0
-        self.total_updates = 0
-
-        for name in ['train', 'update', 'test']:
-            setattr(self, f"time_avg_{name}", 0)
-            setattr(self, f"{name}_count", 0)
-            setattr(self, f"_time_{name}", None)
-
-        self._time_start = time.time()
+
+class Statistics:
+    """Buffer episode statistics and flush them to a scalar writer.
+
+    ``writer`` must provide ``add_scalar(tag, value, step)``; presumably a
+    TensorBoard ``SummaryWriter`` -- confirm against the caller.
+    """
+
+    def __init__(self, writer):
+        self.writer = writer
 
+        # This class never increments these counters itself: it only reads
+        # total_train_episodes and overwrites avg_train_unsafe on flush.
+        self.total_train_episodes = 0
+        self.total_train_steps = 0
+        self.total_train_unsafe = 0
+        self.avg_train_unsafe = 0
 
-    def total_time(self, name=None):
-        if name is None: 
-            t = time.time() - self._time_start
-        else: 
-            t = getattr(self, f"time_avg_{name}") * getattr(self, f"{name}_count")
-        return round(t,2)
-
-
-    def begin(self, name):
-        setattr(self, f"_time_{name}", time.time())
+        self.total_updates = 0
 
+        # Buffered series per group; "_episodes" holds the step of each entry.
+        self.history = {
+            "train" : {
+                "_episodes" : [],
+                "returns" : [],
+                "costs" : [],
+                "steps" : [],
+                "unsafe" : [],
+            },
+            "test/unshielded" : {
+                "_episodes" : [],
+                "avg_returns" : [],
+                "avg_costs" : [],
+                "avg_steps" : [],
+                "avg_unsafe" : [],
+            },
+            "test/shielded" : {
+                "_episodes" : [],
+                "avg_returns" : [],
+                "avg_costs" : [],
+                "avg_steps" : [],
+                "avg_unsafe" : [],
+            },
+        }
+
+    def history_clear(self, name):
+        """Empty every buffered series of the history group ``name``."""
+        for data in self.history[name].values():
+            data.clear()
 
-    def end(self, name):
-        t = time.time()
-        d = t - getattr(self, f"_time_{name}")
-        setattr(self, f"{name}_count", getattr(self, f"{name}_count") + 1)
-        c = getattr(self, f"{name}_count")
-        a = 1/c
-        b = 1 - a
-        setattr(self, f"time_avg_{name}", a*d + b*getattr(self, f"time_avg_{name}"))
+    def history_flush(self, name):
+        """Write the buffered values of group ``name`` to the writer, then clear it.
+
+        Keys starting with "_" are bookkeeping and are not written themselves;
+        "_episodes" supplies the step passed along with each value.
+        """
+        group = self.history[name]
+        episodes = group["_episodes"]
+        for key, values in group.items():
+            if key.startswith("_"):
+                continue
+            path = f"{name}/{key}"
+            for value, episode in zip(values, episodes):
+                self.writer.add_scalar(path, value, episode)
+        self.history_clear(name)
 
+    def _record(self, name, **kwargs):
+        """Append ``kwargs[key]`` to every series of the history group ``name``.
+
+        Each value must offer ``tolist()`` returning a list (e.g. a 1-D numpy
+        array); a key missing from ``kwargs`` raises ``KeyError``.
+        """
+        for key, series in self.history[name].items():
+            series.extend(kwargs[key].tolist())
+
+    def record_train(self, num_episodes, **kwargs):
+        """Buffer the results of ``num_episodes`` training episodes.
+
+        Episodes are numbered consecutively starting at ``total_train_episodes``.
+        ``kwargs`` must hold "returns", "costs", "steps" and "unsafe".
+        """
+        episodes = np.arange(num_episodes) + self.total_train_episodes
+        kwargs["_episodes"] = episodes
+        self._record("train", **kwargs)
 
-    def print(self):
-        time = self.total_time()
-        time_train = self.total_time("train")
-        time_update = self.total_time("update")
-        time_test = self.total_time("test")
-        steps = self.total_steps
-        episodes = self.total_episodes
-        failures = self.total_failures
-        updates = self.total_updates
-        avg_failures = round((self.total_failures/self.total_episodes)*100,2) if self.total_episodes > 0 else 0.00
-        
-        print(f"[TOTAL] total_time: {time}s, total_train: {time_train}s, total_update: {time_update}s, total_test: {time_test}s,\n \
-       steps: {steps}, episodes: {episodes}, failures: {failures}, avg_failures: {avg_failures}%, updates: {updates}")
\ No newline at end of file
+    def record_test(self, shielded, **kwargs):
+        """Buffer test results under "test/shielded" or "test/unshielded".
+
+        Every entry is tagged with the current ``total_train_episodes``.
+        ``kwargs`` must hold "avg_returns", "avg_costs", "avg_steps", "avg_unsafe".
+        """
+        mode = "shielded" if shielded else "unshielded"
+        kwargs["_episodes"] = np.full_like(kwargs["avg_unsafe"], self.total_train_episodes, dtype=np.uint64)
+        self._record(f"test/{mode}", **kwargs)
+
+    def after_exploration(self, train, shielded):
+        """Flush the statistics buffered during one exploration phase.
+
+        Training: refresh ``avg_train_unsafe`` and write every buffered episode.
+        Testing: reduce each buffered series to its mean and write a single
+        point per series, stepped at the most recently recorded episode index.
+        ``shielded`` selects the test group and is ignored when ``train`` is true.
+        """
+        if train:
+            unsafe = self.history["train"]["unsafe"]
+            if unsafe:  # the mean of an empty buffer would be NaN
+                self.avg_train_unsafe = float(np.mean(unsafe))
+            self.history_flush(name="train")
+            return
+
+        mode = "shielded" if shielded else "unshielded"
+        name = f"test/{mode}"
+        group = self.history[name]
+        if group["_episodes"]:
+            step = int(group["_episodes"][-1])
+            for key, values in group.items():
+                # Each series must stay a list (here of length one):
+                # history_flush() zips it and history_clear() calls .clear().
+                if key.startswith("_"):
+                    group[key] = [step]
+                else:
+                    group[key] = [float(np.mean(values))]
+        self.history_flush(name=name)
-- 
GitLab