Skip to content
Snippets Groups Projects
Commit 7dffa991 authored by Philipp Sauer's avatar Philipp Sauer
Browse files

Added stats

parent 69b76df7
Branches
No related tags found
No related merge requests found
import time
import numpy as np
class Statistics():
def __init__(self):
self.total_episodes = 0
self.total_failures = 0
self.total_steps = 0
self.total_updates = 0
for name in ['train', 'update', 'test']:
setattr(self, f"time_avg_{name}", 0)
setattr(self, f"{name}_count", 0)
setattr(self, f"_time_{name}", None)
self._time_start = time.time()
class Statistics:
def __init__(self, writer):
self.writer = writer
self.total_train_episodes = 0
self.total_train_steps = 0
self.total_train_unsafe = 0
self.avg_train_unsafe = 0
def total_time(self, name=None):
if name is None:
t = time.time() - self._time_start
else:
t = getattr(self, f"time_avg_{name}") * getattr(self, f"{name}_count")
return round(t,2)
def begin(self, name):
setattr(self, f"_time_{name}", time.time())
self.total_updates = 0
self.history = {
"train" : {
"_episodes" : [],
"returns" : [],
"costs" : [],
"steps" : [],
"unsafe" : [],
},
"test/unshielded" : {
"_episodes" : [],
"avg_returns" : [],
"avg_costs" : [],
"avg_steps" : [],
"avg_unsafe" : [],
},
"test/shielded" : {
"_episodes" : [],
"avg_returns" : [],
"avg_costs" : [],
"avg_steps" : [],
"avg_unsafe" : [],
},
}
def history_clear(self, name):
for data in self.history[name].values():
data.clear()
def end(self, name):
t = time.time()
d = t - getattr(self, f"_time_{name}")
setattr(self, f"{name}_count", getattr(self, f"{name}_count") + 1)
c = getattr(self, f"{name}_count")
a = 1/c
b = 1 - a
setattr(self, f"time_avg_{name}", a*d + b*getattr(self, f"time_avg_{name}"))
def history_flush(self, name):
for key, values in self.history[name].items():
if key.startswith("_"): continue
path = f"{name}/{key}"
for i in range(len(values)):
self.writer.add_scalar(path, values[i], self.history[name]["_episodes"][i])
self.history_clear(name)
def _record(self, name, **kwargs):
for key in self.history[name].keys():
self.history[name][key] += kwargs[key].tolist()
def record_train(self, num_episodes, **kwargs):
episodes = np.arange(0, num_episodes, 1) + self.total_train_episodes
kwargs["_episodes"] = episodes
self._record("train", **kwargs)
def print(self):
time = self.total_time()
time_train = self.total_time("train")
time_update = self.total_time("update")
time_test = self.total_time("test")
steps = self.total_steps
episodes = self.total_episodes
failures = self.total_failures
updates = self.total_updates
avg_failures = round((self.total_failures/self.total_episodes)*100,2) if self.total_episodes > 0 else 0.00
print(f"[TOTAL] total_time: {time}s, total_train: {time_train}s, total_update: {time_update}s, total_test: {time_test}s,\n \
steps: {steps}, episodes: {episodes}, failures: {failures}, avg_failures: {avg_failures}%, updates: {updates}")
\ No newline at end of file
def record_test(self, shielded, **kwargs):
shielded = "shielded" if shielded else "unshielded"
kwargs["_episodes"] = np.full_like(kwargs["avg_unsafe"], self.total_train_episodes, dtype=np.uint64)
self._record(f"test/{shielded}", **kwargs)
def after_exploration(self, train, shielded):
if train:
self.avg_train_unsafe = np.array(self.history["train"]["unsafe"]).mean().item()
self.history_flush(name="train")
else:
shielded = "shielded" if shielded else "unshielded"
name = f"test/{shielded}"
for key, value in self.history[name].items():
self.history[name][key] = np.array(value).mean().tolist()
self.history_flush(name=name)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment