From 7dffa991951feae4e1f633e584a1b27cbf3b70a5 Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Thu, 20 Feb 2025 14:37:24 +0100 Subject: [PATCH] Added stats --- src/stats.py | 111 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 66 insertions(+), 45 deletions(-) diff --git a/src/stats.py b/src/stats.py index bd886d9..d584f00 100644 --- a/src/stats.py +++ b/src/stats.py @@ -1,52 +1,73 @@ -import time +import numpy as np -class Statistics(): - def __init__(self): - self.total_episodes = 0 - self.total_failures = 0 - self.total_steps = 0 - self.total_updates = 0 - - for name in ['train', 'update', 'test']: - setattr(self, f"time_avg_{name}", 0) - setattr(self, f"{name}_count", 0) - setattr(self, f"_time_{name}", None) - - self._time_start = time.time() +class Statistics: + def __init__(self, writer): + self.writer = writer + self.total_train_episodes = 0 + self.total_train_steps = 0 + self.total_train_unsafe = 0 + self.avg_train_unsafe = 0 - def total_time(self, name=None): - if name is None: - t = time.time() - self._time_start - else: - t = getattr(self, f"time_avg_{name}") * getattr(self, f"{name}_count") - return round(t,2) - - - def begin(self, name): - setattr(self, f"_time_{name}", time.time()) + self.total_updates = 0 + self.history = { + "train" : { + "_episodes" : [], + "returns" : [], + "costs" : [], + "steps" : [], + "unsafe" : [], + }, + "test/unshielded" : { + "_episodes" : [], + "avg_returns" : [], + "avg_costs" : [], + "avg_steps" : [], + "avg_unsafe" : [], + }, + "test/shielded" : { + "_episodes" : [], + "avg_returns" : [], + "avg_costs" : [], + "avg_steps" : [], + "avg_unsafe" : [], + }, + } + + def history_clear(self, name): + for data in self.history[name].values(): + data.clear() - def end(self, name): - t = time.time() - d = t - getattr(self, f"_time_{name}") - setattr(self, f"{name}_count", getattr(self, f"{name}_count") + 1) - c = getattr(self, f"{name}_count") - a = 1/c - b = 1 - a - setattr(self, f"time_avg_{name}", a*d + b*getattr(self, f"time_avg_{name}")) + def history_flush(self, name): + for key, values in self.history[name].items(): + if key.startswith("_"): continue + path = f"{name}/{key}" + for i in range(len(values)): + self.writer.add_scalar(path, values[i], self.history[name]["_episodes"][i]) + self.history_clear(name) + def _record(self, name, **kwargs): + for key in self.history[name].keys(): + self.history[name][key] += kwargs[key].tolist() + + def record_train(self, num_episodes, **kwargs): + episodes = np.arange(0, num_episodes, 1) + self.total_train_episodes + kwargs["_episodes"] = episodes + self._record("train", **kwargs) - def print(self): - time = self.total_time() - time_train = self.total_time("train") - time_update = self.total_time("update") - time_test = self.total_time("test") - steps = self.total_steps - episodes = self.total_episodes - failures = self.total_failures - updates = self.total_updates - avg_failures = round((self.total_failures/self.total_episodes)*100,2) if self.total_episodes > 0 else 0.00 - - print(f"[TOTAL] total_time: {time}s, total_train: {time_train}s, total_update: {time_update}s, total_test: {time_test}s,\n \ - steps: {steps}, episodes: {episodes}, failures: {failures}, avg_failures: {avg_failures}%, updates: {updates}") \ No newline at end of file + def record_test(self, shielded, **kwargs): + shielded = "shielded" if shielded else "unshielded" + kwargs["_episodes"] = np.full_like(kwargs["avg_unsafe"], self.total_train_episodes, dtype=np.uint64) + self._record(f"test/{shielded}", **kwargs) + + def after_exploration(self, train, shielded): + if train: + self.avg_train_unsafe = np.array(self.history["train"]["unsafe"]).mean().item() + self.history_flush(name="train") + else: + shielded = "shielded" if shielded else "unshielded" + name = f"test/{shielded}" + for key, value in self.history[name].items(): + self.history[name][key] = np.array(value).mean().tolist() + self.history_flush(name=name) \ No newline at end of file -- GitLab