diff --git a/main.py b/main.py
index a59a8b464e9c6cd64c543dc8adb2a8bdcf41d0cc..7eef356d171125bab10280363f2ecdecc547f607 100644
--- a/main.py
+++ b/main.py
@@ -1,95 +1,98 @@
 import argparse
-import safety_gymnasium
 import numpy as np
 import torch
 import random
 import os
 import json
 import datetime
-
+import safety_gymnasium
 from torch.utils.tensorboard import SummaryWriter
+
+from src.stats import Statistics
 from src.buffer import ReplayBuffer
 from src.policy import CSCAgent
-from src.stats import Statistics
 
 ##################
 # ARGPARSER
 ##################
 def cmd_args():
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))
     # environment args
-    parser.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
-                        help="Set the environment (default: SafetyPointGoal1-v0)")
-    parser.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
-                        help="Set a cost limit at which point an episode is considered unsafe (default: 25)")
-    parser.add_argument("--enforce_cost_limit", action="store_true", default=False,
-                        help="Aborts episode if cost limit is reached (default: False)")
-    parser.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
-                        help="Sets the number of vectorized environments (default: 16)")
-
-    # train args
-    parser.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
-                        help="Number of episodes until policy optimization (default: 16)")
-    parser.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
-                        help="Perform evaluation after N * train_episodes episodes of training (default: 2)")
-    parser.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
-                        help="Number of updates performed after each training step (default: 3)")
-    parser.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
-                        help="Number of episodes used for testing (default: 32)")
-    parser.add_argument("--total_steps", action="store", type=int, default=25_000_000, metavar="N",
-                        help="Total number of steps until training is finished (default: 25_000_000)")
-    parser.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
-                        help="Batch size used for training (default: 1024)")
-    parser.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
-                        help="Factor used in soft update of target network (default: 0.05)")
-    parser.add_argument("--shielded_action_sampling", action="store_true", default=False,
-                        help="Sample shielded actions when performing parameter updates (default: False)")
-
+    env_args = parser.add_argument_group('Environment')
+    env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
+                        help="Set the environment")
+    env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
+                        help="Set a cost limit/budget")
+    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
+                        help="Sets the number of vectorized environments")
+
+    # train and test args
+    train_test_args = parser.add_argument_group('Train and Test')
+    train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
+                        help="Number of episodes until policy optimization")
+    train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
+                        help="Perform evaluation after N * total_train_episodes episodes of training")
+    train_test_args.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
+                        help="Number of updates performed after each training step")
+    train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
+                        help="Number of episodes used for testing")
+    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
+                        help="Total number of steps until training is finished")
+    train_test_args.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
+                        help="Batch size used for training")
+    train_test_args.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
+                        help="Factor used in soft update of target network")
+    
     # buffer args
-    parser.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
-                        help="Define the maximum capacity of the replay buffer (default: 50_000)")
-    parser.add_argument("--clear_buffer", action="store_true", default=False,
-                        help="Clear Replay Buffer after every optimization step (default: False)")
+    buffer_args = parser.add_argument_group('Buffer')
+    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
+                        help="Define the maximum capacity of the replay buffer")
+    buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
+                        help="Clear Replay Buffer after every optimization step")
 
     # csc args
-    parser.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
-                        help="Maximum number of actions sampled during shielding (default: 100)")
-    parser.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
-                        help="Maximum number of line search update iterations (default: 20)")
-    parser.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
-                        help="Number of samples to estimate expectations (default: 20)")
-    parser.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
-                        help="Set the value of chi (default: 0.05)")
-    parser.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
-                        help="Set the value of delta (default: 0.01)")
-    parser.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N",
-                        help="Set the value of gamma (default: 0.99)")
-    parser.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
-                        help="Set the value of beta (default: 0.7)")
-    parser.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
-                        help="Set the value of alpha (default: 0.5)")
-    parser.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
-                        help="Set the initial value of lambda (default: 1.0)")
-    parser.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
-                        help="Learn rate for the safety critic (default: 2e-4)")
-    parser.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
-                        help="Learn rate for the value network (default: 1e-3)")
-    parser.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N",
-                        help="Learn rate for the lambda dual variable (default: 4e-2)")
-    parser.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
-                        help="Hidden dimension of the networks (default: 32)")
-    parser.add_argument("--sigmoid_activation", action="store_true", default=False,
-                        help="Apply sigmoid activation to the safety critics output (default: False)")
+    csc_args = parser.add_argument_group('Agent')
+    csc_args.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
+                        help="Maximum number of actions sampled during shielding")
+    csc_args.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
+                        help="Maximum number of line search update iterations")
+    csc_args.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
+                        help="Number of samples to estimate expectations")
+    csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
+                        help="Set the value of chi")
+    csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
+                        help="Set the value of delta")
+    csc_args.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N",
+                        help="Set the value of gamma")
+    csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
+                        help="Set the value of beta")
+    csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
+                        help="Set the value of alpha")
+    csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
+                        help="Set the initial value of lambda")
+    csc_args.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
+                        help="Learn rate for the safety critic")
+    csc_args.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
+                        help="Learn rate for the value network")
+    csc_args.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N",
+                        help="Learn rate for the lambda dual variable")
+    csc_args.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
+                        help="Hidden dimension of the networks")
+    csc_args.add_argument("--sigmoid_activation", action="store_true", default=False,
+                        help="Apply sigmoid activation to the safety critics output")
+    csc_args.add_argument("--shielded_action_sampling", action="store_true", default=False,
+                        help="Sample shielded actions when performing parameter updates")
 
     # common args
-    parser.add_argument("--seed", action="store", type=int, default=42, metavar="N",
-                        help="Set a custom seed for the rng (default: 42)")
-    parser.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
-                        help="Set the device for pytorch to use (default: cuda)")
-    parser.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
-                        help="Set the output and log directory path (default: ./runs)")
-    parser.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
-                        help="Set the maximum number of threads for pytorch and numpy (default: 1)")
+    common_args = parser.add_argument_group('Common')
+    common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N",
+                        help="Set a custom seed for the rng")
+    common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
+                        help="Set the device for pytorch to use")
+    common_args.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
+                        help="Set the output and log directory path")
+    common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
+                        help="Set the maximum number of threads for pytorch and numpy")
     args = parser.parse_args()
     return args
 
@@ -98,147 +101,163 @@ def cmd_args():
 ##################
 
 def setup(args):
+    """
+    Performs setup: seeds the RNGs and initializes the environment, replay buffer, statistics logger, and agent.
+    """
     torch.set_num_threads(args.num_threads)
-
-    os.environ["OMP_NUM_THREADS"] = str(args.num_threads) # export OMP_NUM_THREADS=args.num_threads
-    os.environ["OPENBLAS_NUM_THREADS"] = str(args.num_threads) # export OPENBLAS_NUM_THREADS=args.num_threads
-    os.environ["MKL_NUM_THREADS"] = str(args.num_threads) # export MKL_NUM_THREADS=args.num_threads
-    os.environ["VECLIB_MAXIMUM_THREADS"] = str(args.num_threads) # export VECLIB_MAXIMUM_THREADS=args.num_threads
-    os.environ["NUMEXPR_NUM_THREADS"] = str(args.num_threads) # export NUMEXPR_NUM_THREADS=args.num_threads
-
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    torch.set_default_dtype(torch.float64)
 
     output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
     writer = SummaryWriter(log_dir=output_dir)
     with open(os.path.join(output_dir, "config.json"), "w") as file:
         json.dump(args.__dict__, file, indent=2)
 
-    env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=True)
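+    # asynchronous=False selects the synchronous vector wrapper, so sub-environments are stepped sequentially in-process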
+    env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
-    agent = CSCAgent(env, args, writer)
+    stats = Statistics(writer)
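+    # the agent holds references to the buffer and stats, so agent.update() in the main loop takes no arguments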
+    agent = CSCAgent(env, args, buffer, stats)
 
-    stats = Statistics()
-
-    return env, agent, buffer, writer, stats
+    return env, agent, buffer, stats
 
 ##################
 # EXPLORATION
 ##################
 
 @torch.no_grad
-def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agent, buffer, stats, train, shielded):
-
-    avg_steps = 0
-    avg_reward = 0
-    avg_cost = 0
-    avg_failures = 0
-
-    episode_count = args.num_vectorized_envs
-    finished_count = 0
-    num_episodes = args.train_episodes if train else args.test_episodes
-
-    mask = np.ones(args.num_vectorized_envs, dtype='bool')
-    episode_steps = np.zeros(args.num_vectorized_envs)
-    episode_reward = np.zeros(args.num_vectorized_envs)
-    episode_cost = np.zeros(args.num_vectorized_envs)
-
-    if num_episodes < args.num_vectorized_envs:
-        mask[num_episodes:] = False
-        episode_count = num_episodes
+def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
+    # track currently running and leftover episodes
+    open_episodes = args.train_episodes if train else args.test_episodes
+    running_episodes = args.num_vectorized_envs
+
+    # initialize mask and stats per episode
+    mask = np.ones(args.num_vectorized_envs, dtype=np.bool_)
+    episode_steps = np.zeros_like(mask, dtype=np.uint64)
+    episode_reward = np.zeros_like(mask, dtype=np.float64)
+    episode_cost = np.zeros_like(mask, dtype=np.float64)
+
+    # adjust mask in case we have fewer runs than environments
+    if open_episodes < args.num_vectorized_envs:
+        mask[open_episodes:] = False
+        running_episodes = open_episodes
+    open_episodes -= running_episodes
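+    # e.g. with the defaults (16 envs, 32 test episodes): 16 episodes start running now and 16 remain open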
 
     state, info = env.reset(seed=random.randint(0, 2**31-1))
-    while finished_count < num_episodes:
-        action = agent.sample(state, shielded=shielded)
-
-        next_state, reward, cost, terminated, truncated, info = env.step(action)
+    
+    while running_episodes > 0:
+        # sample and execute actions (the @torch.no_grad decorator already disables gradient tracking)
+        actions = agent.sample(state, shielded=shielded).cpu().numpy()
+        next_state, reward, cost, terminated, truncated, info = env.step(actions)
         done = terminated | truncated
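+        # partition the tracked envs into those that just finished an episode and those still running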
+        not_done_masked = ((~done) & mask)
+        done_masked = done & mask
+        done_masked_count = done_masked.sum()
 
+        # increment stats
         episode_steps[mask] += 1
         episode_reward[mask] += reward[mask]
         episode_cost[mask] += cost[mask]
 
-        is_mask_zero = ~mask.any()
-
-        if done.any() or is_mask_zero:
-            avg_steps += episode_steps.sum()
-            avg_reward += episode_reward.sum()
-            avg_cost += episode_cost.sum()
-            avg_failures += (episode_cost >= args.cost_limit).sum()
-
+        # if any tracked env has finished its episode, we need to take special care:
+        # 1. train: add the transition to the buffer using final_observation from the info dict (the sub-envs auto-reset, no manual reset needed)
+        # 2. train+test: log episode stats via the Statistics class
+        # 3. train+test: reset the per-episode stats
+        # 4. train+test: adjust the mask (if necessary)
+        if done_masked_count > 0:
             if train:
-                stats.total_episodes += episode_count - finished_count
-                stats.total_steps += episode_steps.sum()
-                stats.total_failures += (episode_cost >= args.cost_limit).sum()
-
-                if not is_mask_zero:
-                    buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
-
-            mask = np.ones(args.num_vectorized_envs, dtype='bool')
-            episode_steps = np.zeros(args.num_vectorized_envs)
-            episode_reward = np.zeros(args.num_vectorized_envs)
-            episode_cost = np.zeros(args.num_vectorized_envs)
-            state, _ = env.reset()  # auto resets, but is_mask_zero requires us to reset
-
-            finished_count = episode_count
-            open_episodes = num_episodes - episode_count
-            idx = min(open_episodes, args.num_vectorized_envs)
-            mask[idx:] = False
-            episode_count += idx
-        
+                # add experiences to buffer
+                buffer.add(
+                    state[not_done_masked], 
+                    actions[not_done_masked], 
+                    reward[not_done_masked], 
+                    cost[not_done_masked], 
+                    next_state[not_done_masked],
+                    done[not_done_masked]
+                )
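+                # for envs that just finished, next_state already holds the auto-reset observation,
+                # so store the true terminal observation from info['final_observation'] instead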
+                buffer.add(
+                    state[done_masked], 
+                    actions[done_masked], 
+                    reward[done_masked], 
+                    cost[done_masked], 
+                    np.stack(info['final_observation'], axis=0)[done_masked],
+                    done[done_masked]
+                )
+
+                # record finished episodes
+                stats.record_train(
+                    num_episodes=done_masked_count,
+                    returns=episode_reward[done_masked],
+                    costs=episode_cost[done_masked],
+                    steps=episode_steps[done_masked],
+                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
+                )
+                stats.total_train_episodes += done_masked_count
+                stats.total_train_steps += episode_steps[done_masked].sum()
+                stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
+            
+            else:
+                # record finished episodes
+                # stats module performs averaging over all episodes
+                stats.record_test(
+                    shielded=shielded,
+                    avg_returns=episode_reward[done_masked],
+                    avg_costs=episode_cost[done_masked],
+                    avg_steps=episode_steps[done_masked],
+                    avg_unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
+                )
+            
+            # advance the state and reset per-episode stats for the finished envs
+            state = next_state
+            episode_steps[done_masked] = 0
+            episode_reward[done_masked] = 0
+            episode_cost[done_masked] = 0
+
+            # adjust mask, running and open episodes counter
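+            # e.g. 4 tracked envs just finished but only 1 episode is still open: keep tracking 1 of them, mask out the other 3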
+            if open_episodes < done_masked_count:  # fewer left than just finished
+                done_masked_idxs = done_masked.nonzero()[0]
+                mask[done_masked_idxs[open_episodes:]] = False
+                running_episodes -= (done_masked_count - open_episodes)
+                open_episodes = 0
+            else:   # at least as many left as just finished
+                open_episodes -= done_masked_count
+
+        # no run has finished, just record experiences (if training)
         else:
             if train:
-                buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
-                if args.enforce_cost_limit:   # we dont care about the cost limit while testing
-                    mask = mask & (episode_cost < args.cost_limit)
+                buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask])
             state = next_state
     
-    avg_steps /= num_episodes
-    avg_reward /= num_episodes
-    avg_cost /= num_episodes
-    avg_failures /= num_episodes
-
-    return avg_steps, avg_reward, avg_cost, avg_failures
+    # after exploration, flush stats
+    stats.after_exploration(train, shielded)
 
 ##################
 # MAIN LOOP
 ##################
 
-def main(args, env, agent, buffer, writer, stats:Statistics):
+def main(args, env, agent, buffer, stats):
     finished = False
     while not finished:
+        # Training + Update Loop
         for _ in range(args.train_until_test):
-            if stats.total_steps >= args.total_steps:
-                finished = True
-                break
-            stats.begin(name="train")
-            avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
-            print(f"[TRAIN] avg_steps: {round(avg_steps, 4)}, avg_reward: {round(avg_reward, 4)}, avg_cost: {round(avg_cost, 4)}, avg_failures: {round(avg_failures, 4)}")
-            stats.end(name="train")
-            stats.begin(name="update")
-            for i in range(args.update_iterations):
-                stats.total_updates += 1
-                agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=stats.total_episodes + i)
-            stats.end(name="update")
+
+            # stop once the total training-step budget is exhausted
+            if stats.total_train_steps >= args.total_train_steps:
+                finished = True
+                break
+
+            # 1. Run exploration for training
+            run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
+
+            # 2. Perform updates
+            for _ in range(args.update_iterations):
+                agent.update()
+
+            # 3. Post-update housekeeping
             agent.after_updates()
             if args.clear_buffer:
                 buffer.clear()
 
-        stats.begin(name="test")
-        for shielded, postfix in zip([True, False], ["shielded", "unshielded"]):
-            avg_steps, avg_reward, avg_cost, avg_failures = run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
-            writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, stats.total_episodes)
-            writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, stats.total_episodes)
-            writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, stats.total_episodes)
-            print(f"[TEST_{postfix.upper()}] avg_steps: {round(avg_steps, 4)}, avg_reward: {round(avg_reward, 4)}, avg_cost: {round(avg_cost, 4)}, avg_failures: {round(avg_failures, 4)}")
-        stats.end(name="test")
-        stats.print()
-
-    writer.flush()
+        # Test loop (shielded and unshielded)
+        for shielded in [True, False]:
+            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
 
 if __name__ == '__main__':
     args = cmd_args()
-    env, agent, buffer, writer, stats = setup(args)
-    main(args, env, agent, buffer, writer, stats)
\ No newline at end of file
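+    # setup() returns (env, agent, buffer, stats) in the order main() expects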
+    main(args, *setup(args))
\ No newline at end of file