From 17ca6a32c91a45d741fb99422161e814ea9e9a88 Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Mon, 30 Sep 2024 13:56:55 +0200 Subject: [PATCH] Added more detailed statistics --- main.py | 55 ++++++++++++++++++++++++------------------------ src/stats.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 28 deletions(-) create mode 100644 src/stats.py diff --git a/main.py b/main.py index d01e3fc..943b15c 100644 --- a/main.py +++ b/main.py @@ -10,15 +10,7 @@ import datetime from torch.utils.tensorboard import SummaryWriter from src.buffer import ReplayBuffer from src.policy import CSCAgent - -################## -# GLOBALS -################## - -total_episodes = 0 -total_failures = 0 -total_steps = 0 -total_updates = 0 +from src.stats import Statistics ################## # ARGPARSER @@ -117,15 +109,17 @@ def setup(args): buffer = ReplayBuffer(env=env, cap=args.buffer_capacity) agent = CSCAgent(env, args, writer) - return env, agent, buffer, writer + stats = Statistics() + + return env, agent, buffer, writer, stats ################## # EXPLORATION ################## @torch.no_grad -def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agent, buffer, writer, train, shielded): - global total_episodes, total_steps, total_failures +def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agent, buffer, stats, train, shielded): + avg_steps = 0 avg_reward = 0 avg_cost = 0 @@ -163,9 +157,9 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen avg_cost += episode_cost.sum() avg_failures += (episode_cost >= args.cost_limit).sum() - total_episodes += episode_count - finished_count - total_steps += episode_steps.sum() - total_failures += (episode_cost >= args.cost_limit).sum() + stats.total_episodes += episode_count - finished_count + stats.total_steps += episode_steps.sum() + stats.total_failures += (episode_cost >= args.cost_limit).sum() if train and not is_mask_zero: buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask]) @@ -200,33 +194,38 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen # MAIN LOOP ################## -def main(args, env, agent, buffer, writer): - global total_episodes, total_steps, total_failures, total_updates +def main(args, env, agent, buffer, writer, stats:Statistics): finished = False while not finished: for _ in range(args.train_until_test): - if total_steps >= args.total_steps: + if stats.total_steps >= args.total_steps: finished = True break - avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, writer, train=True, shielded=True) + stats.begin(name="train") + avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True) print(f"[TRAIN] avg_steps: {round(avg_steps, 4)}, avg_reward: {round(avg_reward, 4)}, avg_cost: {round(avg_cost, 4)}, avg_failures: {round(avg_failures, 4)}") + stats.end(name="train") + stats.begin(name="update") for i in range(args.update_iterations): - total_updates += 1 - agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=total_episodes) + stats.total_updates += 1 + agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=stats.total_episodes + i) + stats.end(name="update") if args.clear_buffer: buffer.clear() - print(f"[TOTAL] episodes: {total_episodes}, steps: {total_steps}, failures: {total_failures}, updates: {total_updates}") + stats.begin(name="test") for shielded, postfix in zip([True, False], ["shielded", "unshielded"]): - avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, writer, train=False, shielded=shielded) - writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, total_episodes) - writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, total_episodes) - writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, total_episodes) + avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded) + writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, stats.total_episodes) + writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, stats.total_episodes) + writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, stats.total_episodes) print(f"[TEST_{postfix.upper()}] avg_steps: {round(avg_steps, 4)}, avg_reward: {round(avg_reward, 4)}, avg_cost: {round(avg_cost, 4)}, avg_failures: {round(avg_failures, 4)}") + stats.end(name="test") + stats.print() writer.flush() if __name__ == '__main__': args = cmd_args() - env, agent, buffer, writer = setup(args) - main(args, env, agent, buffer, writer) \ No newline at end of file + env, agent, buffer, writer, stats = setup(args) + main(args, env, agent, buffer, writer, stats) \ No newline at end of file diff --git a/src/stats.py b/src/stats.py new file mode 100644 index 0000000..cd59501 --- /dev/null +++ b/src/stats.py @@ -0,0 +1,59 @@ +import time + +class Statistics(): + def __init__(self): + self.total_episodes = 0 + self.total_failures = 0 + self.total_steps = 0 + self.total_updates = 0 + + self.time_avg_train = 0 + self.train_count = 0 + self._time_train = None + + self.time_avg_update = 0 + self.update_count = 0 + self._time_update = None + + self.time_avg_test = 0 + self.test_count = 0 + self._time_test = None + + self.time_start = time.time() + + + def total_time(self, name=None): + if name is None: + t = time.time() - self.time_start + else: + t = self.__dict__[f"time_avg_{name}"] * self.__dict__[f"{name}_count"] + return round(t,2) + + + def begin(self, name): + self.__dict__[f"_time_{name}"] = time.time() + + + def end(self, name): + t = time.time() + d = t - self.__dict__[f"_time_{name}"] + self.__dict__[f"{name}_count"] += 1 + c = self.__dict__[f"{name}_count"] + a = 1/c + b = 1 - a + self.__dict__[f"time_avg_{name}"] = a*d + b*self.__dict__[f"time_avg_{name}"] + + + def print(self): + time = self.total_time() + time_train = self.total_time("train") + time_update = self.total_time("update") + time_test = self.total_time("test") + steps = self.total_steps + episodes = self.total_episodes + failures = self.total_failures + updates = self.total_updates + avg_failures = round((self.total_failures/self.total_episodes)*100,2) if self.total_episodes > 0 else 0.00 + + print(f"[TOTAL] total_time: {time}s, total_train: {time_train}s, total_update: {time_update}s, total_test: {time_test}s,\n \ + steps: {steps}, episodes: {episodes}, failures: {failures}, avg_failures: {avg_failures}%, updates: {updates}") \ No newline at end of file -- GitLab