diff --git a/main.py b/main.py
index d9d5b7c0eb05053b4bbe887549ca9274f02fec99..a32e903157287163fb07253f67282d58e4646247 100644
--- a/main.py
+++ b/main.py
@@ -5,12 +5,12 @@ import random
 import os
 import json
 import datetime
-import safety_gymnasium
 from torch.utils.tensorboard import SummaryWriter
 
 from src.stats import Statistics
 from src.buffer import ReplayBuffer
 from src.cql_sac.agent import CQLSAC
+from src.environment import create_environment
 
 ##################
 # ARGPARSER
@@ -96,6 +96,7 @@ def cmd_args():
                         help="Set the output and log directory path")
     common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
                         help="Set the maximum number of threads for pytorch and numpy")
+
     args = parser.parse_args()
     return args
 
@@ -117,9 +118,9 @@ def setup(args):
     with open(os.path.join(output_dir, "config.json"), "w") as file:
         json.dump(args.__dict__, file, indent=2)
 
-    env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
+    env = create_environment(args=args)
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
-    stats = Statistics(writer)
+    stats = Statistics(writer=writer)
     agent = CQLSAC(env=env, args=args, stats=stats)
 
     return env, agent, buffer, stats
@@ -129,7 +130,7 @@ def setup(args):
 ##################
 
 @torch.no_grad
-def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
+def run_vectorized_exploration(args, env, agent, buffer, stats: Statistics, train=True, shielded=True):
     # track currently running and leftover episodes
     open_episodes = args.train_episodes if train else args.test_episodes
     running_episodes = args.num_vectorized_envs
@@ -188,26 +189,24 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie
                 )
 
                 # record finished episodes
-                stats.record_train(
-                    num_episodes=done_masked_count,
-                    returns=episode_reward[done_masked],
-                    costs=episode_cost[done_masked],
-                    steps=episode_steps[done_masked],
-                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
-                )
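+                # x-axis ticks: cumulative environment steps at the end of each finished episode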
+                ticks = stats.total_train_steps + np.cumsum(episode_steps[done_masked], axis=0)
+                stats.log_tensorboard("train/returns", episode_reward[done_masked], ticks)
+                stats.log_tensorboard("train/costs", episode_cost[done_masked], ticks)
+                stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks)
+                stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks)
+
                 stats.total_train_episodes += done_masked_count
                 stats.total_train_steps += episode_steps[done_masked].sum()
                 stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
             
             else:
                 # record finished episodes
-                # stats module performs averaging over all episodes
-                stats.record_test(
-                    shielded=shielded,
-                    avg_returns=episode_reward[done_masked],
-                    avg_costs=episode_cost[done_masked],
-                    avg_steps=episode_steps[done_masked],
-                    avg_unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
+                # stats module averages over all finished episodes when logging
+                stats.log_test_history(
+                    steps=episode_steps[done_masked],
+                    reward=episode_reward[done_masked],
+                    cost=episode_cost[done_masked],
+                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
                 )
             
             # reset episode stats
@@ -230,9 +229,59 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie
             if train:
                 buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask])
             state = next_state
-    
-    # after exploration, flush stats
-    stats.after_exploration(train, shielded)
+
+    # average and log finished test episodes
+    if not train:
+        stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps)
+
+def single_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
+    # NOTE: Unused function
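+    # Single-environment rollout that mirrors run_vectorized_exploration without vectorization.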
+    state, info = env.reset(seed=random.randint(0, 2**31-1))
+    episode_reward = 0
+    episode_cost = 0
+    episode_steps = 0
+    done = False
+
+    while not done:
+        with torch.no_grad():
+            action = agent.get_action(np.expand_dims(state, 0), eval=not train).cpu().numpy().squeeze(0)
+        next_state, reward, cost, terminated, truncated, info = env.step(action)
+        done = terminated or truncated
+        episode_reward += reward
+        episode_cost += cost
+        episode_steps += 1
+
+        state = np.expand_dims(state, 0)
+        next_state = np.expand_dims(next_state, 0)
+        action = np.expand_dims(action, 0)
+        reward = np.array([reward])
+        cost = np.array([cost])
+
+        if train:
+            buffer.add(state, action, reward, cost, next_state, np.array([terminated]))
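+            # only start gradient updates once enough experience has been collected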
+            if stats.total_train_steps >= 10_000:
+                x = agent.learn(experiences=buffer.sample(n=args.batch_size))
+                actor_loss, alpha_loss, critic1_loss, critic2_loss, cql1_scaled_loss, cql2_scaled_loss, current_alpha, cql_alpha_loss, cql_alpha = x
+                stats.writer.add_scalar("debug/actor_loss", actor_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/alpha_loss", alpha_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/critic1_loss", critic1_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/critic2_loss", critic2_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/cql1_scaled_loss", cql1_scaled_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/cql2_scaled_loss", cql2_scaled_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/current_alpha", current_alpha, stats.total_updates)
+                stats.writer.add_scalar("debug/cql_alpha_loss", cql_alpha_loss, stats.total_updates)
+                stats.writer.add_scalar("debug/cql_alpha", cql_alpha, stats.total_updates)
+                stats.total_updates += 1
+
+        state = next_state.squeeze(0)
+        
+    stats.writer.add_scalar("train/returns", episode_reward, stats.total_train_steps)
+    stats.writer.add_scalar("train/costs", episode_cost, stats.total_train_steps)
+    stats.writer.add_scalar("train/steps", episode_steps, stats.total_train_steps)
+
+    stats.total_train_episodes += 1
+    stats.total_train_steps += episode_steps
+    stats.total_train_unsafe += int(episode_cost > args.cost_limit)
 
 ##################
 # MAIN LOOP
@@ -257,7 +306,7 @@ def main(args, env, agent, buffer, stats):
 
         # Test loop (shielded and unshielded)
         for shielded in [True, False]:
-            run_vectorized_exploration(args, env, agent, buffer, train=False, shielded=shielded)
+            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
 
 if __name__ == '__main__':
     args = cmd_args()
diff --git a/src/environment.py b/src/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb89717d076740298985b5eccfa8703497562c7
--- /dev/null
+++ b/src/environment.py
@@ -0,0 +1,29 @@
+import safety_gymnasium
+import gymnasium
+import numpy as np
+
+class Gymnasium2SafetyGymnasium(gymnasium.Wrapper):
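+    # Adapts a plain Gymnasium env to the Safety-Gymnasium step signature:
+    # returns (state, reward, cost, terminated, truncated, info), with cost
+    # taken from info['cost'] when available and zero otherwise.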
+    def step(self, action):
+        state, reward, terminated, truncated, info = super().step(action)
+        if 'cost' in info:
+            cost = info['cost']
+        elif isinstance(reward, (int, float)):
+            cost = 0
+        elif isinstance(reward, np.ndarray):
+            cost = np.zeros_like(reward)
+        elif isinstance(reward, list):
+            cost = [0]*len(reward)
+        else:
+            raise NotImplementedError("reward type not recognized") # for now
+        return state, reward, cost, terminated, truncated, info
+
+def create_environment(args):
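+    # Build a synchronous vectorized environment from args.env_id: Safety-Gymnasium
+    # ids are used directly, "Gymnasium_"-prefixed ids are wrapped to add a cost signal.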
+    if args.env_id.startswith("Safety"):
+        env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
+    if args.env_id.startswith("Gymnasium_"):
+        id = args.env_id[len('Gymnasium_'):]
+        env = gymnasium.make_vec(id, num_envs=args.num_vectorized_envs, vectorization_mode="sync")
+        env = Gymnasium2SafetyGymnasium(env)
+    if args.env_id.startswith("RaceTrack"):
+        raise NotImplementedError("RaceTrack environment is not implemented yet.")
+    return env    
\ No newline at end of file
diff --git a/src/stats.py b/src/stats.py
index d584f005b169b92ac69cc2697b56f4f20c8f904d..8b7494dec164ecaff2614582eb49da9c343b9aef 100644
--- a/src/stats.py
+++ b/src/stats.py
@@ -4,70 +4,52 @@ class Statistics:
     def __init__(self, writer):
         self.writer = writer
 
+        # Total training statistics
         self.total_train_episodes = 0
         self.total_train_steps = 0
         self.total_train_unsafe = 0
-        self.avg_train_unsafe = 0
-
         self.total_updates = 0
 
-        self.history = {
-            "train" : {
-                "_episodes" : [],
-                "returns" : [],
-                "costs" : [],
-                "steps" : [],
-                "unsafe" : [],
-            },
-            "test/unshielded" : {
-                "_episodes" : [],
-                "avg_returns" : [],
-                "avg_costs" : [],
-                "avg_steps" : [],
-                "avg_unsafe" : [],
-            },
-            "test/shielded" : {
-                "_episodes" : [],
-                "avg_returns" : [],
-                "avg_costs" : [],
-                "avg_steps" : [],
-                "avg_unsafe" : [],
-            },
-        }
-    
-    def history_clear(self, name):
-        for data in self.history[name].values():
-            data.clear()
+        # Used for calculating the average unsafe rate over previous training episodes
+        self.train_unsafe_history = []
+        self._train_unsafe_avg = 0
 
-    def history_flush(self, name):
-        for key, values in self.history[name].items():
-            if key.startswith("_"): continue
-            path = f"{name}/{key}"
-            for i in range(len(values)):
-                self.writer.add_scalar(path, values[i], self.history[name]["_episodes"][i])
-        self.history_clear(name)
+        # Used for averaging test results
+        self.test_steps_history = []
+        self.test_reward_history = []
+        self.test_cost_history = []
+        self.test_unsafe_history = []
 
-    def _record(self, name, **kwargs):
-        for key in self.history[name].keys():
-            self.history[name][key] += kwargs[key].tolist()
+    @property
+    def train_unsafe_avg(self):
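+        # Reading this property recomputes the running average and clears the accumulated history.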
+        if len(self.train_unsafe_history) > 0:
+            self._train_unsafe_avg = sum(self.train_unsafe_history) / len(self.train_unsafe_history)
+            self.train_unsafe_history.clear()
+        return self._train_unsafe_avg
     
-    def record_train(self, num_episodes, **kwargs):
-        episodes = np.arange(0, num_episodes, 1) + self.total_train_episodes
-        kwargs["_episodes"] = episodes
-        self._record("train", **kwargs)
-
-    def record_test(self, shielded, **kwargs):
-        shielded = "shielded" if shielded else "unshielded"
-        kwargs["_episodes"] = np.full_like(kwargs["avg_unsafe"], self.total_train_episodes, dtype=np.uint64)
-        self._record(f"test/{shielded}", **kwargs)
+    def log_train_history(self, unsafe: np.ndarray):
+        self.train_unsafe_history += unsafe.tolist()
     
-    def after_exploration(self, train, shielded):
-        if train:
-            self.avg_train_unsafe = np.array(self.history["train"]["unsafe"]).mean().item()
-            self.history_flush(name="train")
+    def log_test_history(self, steps: np.ndarray, reward: np.ndarray, cost: np.ndarray, unsafe: np.ndarray):
+        self.test_steps_history += steps.tolist()
+        self.test_reward_history += reward.tolist()
+        self.test_cost_history += cost.tolist()
+        self.test_unsafe_history += unsafe.tolist()
+    
+    def log_test_tensorboard(self, name_prefix, ticks):
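+        # Log the averages over all buffered test episodes under name_prefix, then clear the buffers.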
+        self.log_tensorboard(name_prefix + '/avg_steps', np.array(self.test_steps_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_returns', np.array(self.test_reward_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_costs', np.array(self.test_cost_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_unsafes', np.array(self.test_unsafe_history).mean(), ticks)
+
+        self.test_steps_history.clear()
+        self.test_reward_history.clear()
+        self.test_cost_history.clear()
+        self.test_unsafe_history.clear()
+
+    def log_tensorboard(self, name: str, values, ticks):
+        # Accept numpy arrays as well as plain Python scalars (e.g. an int step counter) for values and ticks
+        values = np.asarray(values)
+        ticks = np.asarray(ticks)
+        if values.size == 1:
+            self.writer.add_scalar(name, values.item(), ticks.item())
         else:
-            shielded = "shielded" if shielded else "unshielded"
-            name = f"test/{shielded}"
-            for key, value in self.history[name].items():
-                self.history[name][key] = np.array(value).mean().tolist()
-            self.history_flush(name=name)
\ No newline at end of file
+            for v, t in zip(values, ticks):
+                self.writer.add_scalar(name, v, t)
\ No newline at end of file