diff --git a/main.py b/main.py
index bf05dfb64a1147af9a4fac7ab988b081b365c0af..23f6cb2eeba13787d243bcdd8d56caad9f9d1b77 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter
 from src.stats import Statistics
 from src.buffer import ReplayBuffer
 from src.cql_sac.agent import CSCCQLSAC
-from src.environment import create_environment
+from src.environment import create_env
 
 ##################
 # ARGPARSER
@@ -21,20 +21,20 @@ def cmd_args():
     env_args = parser.add_argument_group('Environment')
     env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
                         help="Set the environment")
-    env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
+    env_args.add_argument("--env_args", action="store", type=str, default="",
+                        help="Environment-specific arguments")
+    env_args.add_argument("--cost_limit", action="store", type=float, default=0.0, metavar="N",
                         help="Set a cost limit/budget")
-    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=8, metavar="N",
-                        help="Sets the number of vectorized environments")
 
     # train and test args
     train_test_args = parser.add_argument_group('Train and Test')
-    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
+    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=5_000_000, metavar="N",
                     help="Total number of steps until training is finished")
-    train_test_args.add_argument("--train_episodes", action="store", type=int, default=8, metavar="N",
+    train_test_args.add_argument("--train_episodes", action="store", type=int, default=4, metavar="N",
                         help="Number of episodes until policy optimization")
-    train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
+    train_test_args.add_argument("--train_until_test", action="store", type=int, default=4, metavar="N",
                         help="Perform evaluation after N * total_train_episodes episodes of training")
-    train_test_args.add_argument("--update_iterations", action="store", type=int, default=128, metavar="N",
+    train_test_args.add_argument("--update_iterations", action="store", type=int, default=64, metavar="N",
                         help="Number of updates performed after each training step")
     train_test_args.add_argument("--test_episodes", action="store", type=int, default=16, metavar="N",
                         help="Number of episodes used for testing")
@@ -43,16 +43,16 @@ def cmd_args():
     update_args = parser.add_argument_group('Update')
     update_args.add_argument("--batch_size", action="store", type=int, default=256, metavar="N",
                         help="Batch size used for training")
-    update_args.add_argument("--tau", action="store", type=float, default=5e-3, metavar="N",
+    update_args.add_argument("--tau", action="store", type=float, default=0.005, metavar="N",
                         help="Factor used in soft update of target network")
     update_args.add_argument("--gamma", action="store", type=float, default=0.95, metavar="N",
                         help="Discount factor for rewards")
-    update_args.add_argument("--learning_rate", action="store", type=float, default=3e-4, metavar="N",
+    update_args.add_argument("--learning_rate", action="store", type=float, default=0.0003, metavar="N",
                         help="Learn rate for the policy and Q networks")
     
     # buffer args
     buffer_args = parser.add_argument_group('Buffer')
-    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=100_000, metavar="N",
+    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=250_000, metavar="N",
                         help="Define the maximum capacity of the replay buffer")
     buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
                         help="Clear Replay Buffer after every optimization step")
@@ -88,7 +88,7 @@ def cmd_args():
 
     # common args
     common_args = parser.add_argument_group('Common')
-    common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N",
+    common_args.add_argument("--seed", action="store", type=int, default=0, metavar="N",
                         help="Set a custom seed for the rng")
     common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
                         help="Set the device for pytorch to use")
@@ -118,7 +118,7 @@ def setup(args):
     with open(os.path.join(output_dir, "config.json"), "w") as file:
         json.dump(args.__dict__, file, indent=2)
 
-    env = create_environment(args=args)
+    env = create_env(args=args)
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
     stats = Statistics(writer=writer)
     agent = CSCCQLSAC(env=env, args=args, stats=stats)
@@ -130,111 +130,85 @@ def setup(args):
 ##################
 
 @torch.no_grad
-def run_vectorized_exploration(args, env, agent:CSCCQLSAC, buffer:ReplayBuffer, stats:Statistics, train=True, shielded=True):
-    # track currently running and leftover episodes
-    open_episodes = args.train_episodes if train else args.test_episodes
-    running_episodes = args.num_vectorized_envs
-
-    # initialize mask and stats per episode
-    mask = np.ones(args.num_vectorized_envs, dtype=np.bool_)
-    episode_steps = np.zeros_like(mask, dtype=np.uint64)
-    episode_reward = np.zeros_like(mask, dtype=np.float64)
-    episode_cost = np.zeros_like(mask, dtype=np.float64)
-
-    # adjust mask in case we have fewer runs than environments
-    if open_episodes < args.num_vectorized_envs:
-        mask[open_episodes:] = False
-        running_episodes = open_episodes
-    open_episodes -= running_episodes
-
-    state, info = env.reset(seed=random.randint(0, 2**31-1))
+def run_single_exploration(args, env, agent:CSCCQLSAC, buffer:ReplayBuffer, stats:Statistics, train=True, shielded=True):
+    # track leftover episodes
+    num_episodes = args.train_episodes if train else args.test_episodes
+
+    avg_steps = 0.
+    avg_return = 0.
+    avg_cost = 0.
+    avg_unsafe = 0.
     
-    while running_episodes > 0:
-        # sample and execute actions
-        actions = agent.get_action(state, shielded=shielded).cpu().numpy()
-        next_state, reward, cost, terminated, truncated, info = env.step(actions)
-        done = terminated | truncated
-        not_done_masked = ((~done) & mask)
-        done_masked = done & mask
-        done_masked_count = done_masked.sum()
-
-        # increment stats
-        episode_steps[mask] += 1
-        episode_reward[mask] += reward[mask]
-        episode_cost[mask] += cost[mask]
-
-        # if any run has finished, we need to take special care
-        # 1. train: extract final_observation from info dict (single envs autoreset, no manual reset needed) and add to buffer
-        # 2. train+test: log episode stats using Statistics class
-        # 3. train+test: reset episode stats
-        # 4. train+test: adjust mask (if necessary)
-        if done_masked_count > 0:
+    for _ in range(num_episodes):
+        # initialize stats per episode
+        episode_steps = 0
+        episode_reward = 0.
+        episode_cost = 0.
+        done = False
+
+        # reset env
+        state, info = env.reset(seed=random.randint(0, 2**31-1))
+
+        while not done:
+            # sample and execute actions
+            action = agent.get_action(state, shielded=shielded).cpu().numpy()
+            next_state, reward, cost, terminated, truncated, info = env.step(action)
+            done = terminated or truncated
+
+            # increment stats
+            episode_steps += 1
+            episode_reward += reward
+            episode_cost += cost
+
+            # During training, add experiences to buffer
             if train:
-                # add experiences to buffer
-                buffer.add(
-                    state[not_done_masked], 
-                    actions[not_done_masked], 
-                    reward[not_done_masked], 
-                    cost[not_done_masked], 
-                    next_state[not_done_masked],
-                    terminated[not_done_masked]
-                )
                 buffer.add(
-                    state[done_masked], 
-                    actions[done_masked], 
-                    reward[done_masked], 
-                    cost[done_masked], 
-                    np.stack(info['final_observation'], axis=0)[done_masked],
-                    terminated[done_masked]
+                    state, action, reward, cost, next_state, terminated
                 )
-
-                # record finished episodes
-                ticks = stats.total_train_steps + np.cumsum(episode_steps[done_masked], axis=0)
-                stats.log_tensorboard("train/returns", episode_reward[done_masked], ticks)
-                stats.log_tensorboard("train/costs", episode_cost[done_masked], ticks)
-                stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks)
-                stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks)
-
-                stats.log_train_history((episode_cost[done_masked] > args.cost_limit).astype(np.uint8))
-
-                stats.total_train_episodes += done_masked_count
-                stats.total_train_steps += episode_steps[done_masked].sum()
-                stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
             
-            else:
-                # record finished episodes
-                # stats module performs averaging over all when logging episodes
-                stats.log_test_history(
-                    steps=episode_steps[done_masked],
-                    reward=episode_reward[done_masked],
-                    cost=episode_cost[done_masked],
-                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
-                )
+            # End episode if the cost exceeded the limit
+            if episode_cost > args.cost_limit:
+                done = True
             
-            # reset episode stats
-            state = next_state
-            episode_steps[done_masked] = 0
-            episode_reward[done_masked] = 0
-            episode_cost[done_masked] = 0
-
-            # adjust mask, running and open episodes counter
-            if open_episodes < done_masked_count:  # fewer left than just finished
-                done_masked_idxs = done_masked.nonzero()[0]
-                mask[done_masked_idxs[open_episodes:]] = False
-                running_episodes -= (done_masked_count - open_episodes)
-                open_episodes = 0
-            else:   # at least as many left than just finished
-                open_episodes -= done_masked_count
-
-        # no run has finished, just record experiences (if training)
-        else:
-            if train:
-                buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask])
+            # Update state
             state = next_state
 
-    # average and log finished test episodes
-    if not train:
-        stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps)
+        # Update statistics
+        count_unsafe = int(episode_cost > args.cost_limit)
+        stats.total_train_episodes += 1
+        stats.total_train_steps += episode_steps
+        stats.total_train_unsafe += count_unsafe
+
+        # record finished episodes
+        stats.log_train_tensorboard(
+            episode_steps=episode_steps,
+            episode_return=episode_reward,
+            episode_cost=episode_cost
+        )
+        
+        avg_return += episode_reward
+        avg_steps += episode_steps
+        avg_cost += episode_cost
+        avg_unsafe += count_unsafe
+
+    # average and log stats
+    avg_return /= num_episodes
+    avg_steps /= num_episodes
+    avg_cost /= num_episodes
+    avg_unsafe /= num_episodes
+
+    # In training, update avg_unsafe
+    # In testing, log averages
+    if train:
+        stats.train_unsafe_avg = avg_unsafe
+    else:
+        stats.log_test_tensorboard(
+            avg_steps=avg_steps,
+            avg_return=avg_return,
+            avg_cost=avg_cost,
+            avg_unsafe=avg_unsafe,
+            shielded=shielded
+        )
 
 ##################
 # MAIN LOOP
@@ -252,11 +226,12 @@ def main(args, env, agent, buffer, stats):
                 break
 
             # 1. Run exploration for training
-            run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
+            run_single_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
 
             # 2. Perform updates
             for _ in range(args.update_iterations):
-                agent.learn(experiences=buffer.sample(n=args.batch_size))
+                experiences = buffer.sample(n=args.batch_size)
+                agent.learn(experiences=experiences)
 
             # 3. After update stuff
             if args.clear_buffer:
@@ -264,7 +239,7 @@ def main(args, env, agent, buffer, stats):
 
         # Test loop (shielded and unshielded)
         for shielded in [True, False]:
-            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
+            run_single_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
 
 if __name__ == '__main__':
     args = cmd_args()
diff --git a/src/buffer.py b/src/buffer.py
index 6b72adbb5bc8264b4db5d9144b4f9bfabb3e686d..c1c839a3760b28e89b6d8b40ea8e7020846e3f7d 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -2,54 +2,37 @@ import numpy as np
 
 class ReplayBuffer():
     """
-    Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. Handles batched experiences.
+    Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer.
     """
     def __init__(self, env, cap):
         self._cap = max(1,cap)
         self._size = 0              # number of experiences in the buffer
         self._ptr = 0               # pointer to the next available slot in the buffer
 
-        self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64)
-        self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64)
+        self._states = np.zeros((cap, env.observation_space.shape[0]), dtype=np.float64)
+        self._actions = np.zeros((cap, env.action_space.shape[0]), dtype=np.float64)
         self._rewards = np.zeros((cap, ), dtype=np.float64)
         self._costs = np.zeros((cap, ), dtype=np.float64)
         self._next_states = np.zeros_like(self._states)
         self._dones = np.zeros((cap, ), dtype=np.uint8)
 
-
-    def _add(self, state, action, reward, cost, next_state, done, start, end):
-        self._states[start:end] = state
-        self._actions[start:end] = action
-        self._rewards[start:end] = reward
-        self._costs[start:end] = cost
-        self._next_states[start:end] = next_state
-        self._dones[start:end] = done
-
+    def _add(self, state, action, reward, cost, next_state, done):
+        idx = self._ptr
+        self._states[idx] = state
+        self._actions[idx] = action
+        self._rewards[idx] = reward
+        self._costs[idx] = cost
+        self._next_states[idx] = next_state
+        self._dones[idx] = done
 
     def add(self, state, action, reward, cost, next_state, done):
         """
-        Adds experiences to the buffer. Assumes batched experiences.
+        Adds an experience to the buffer.
         """
-        n = state.shape[0]          # NOTE: n should be less than or equal to the buffer capacity
-        idx_start = self._ptr
-        idx_end = self._ptr + n
-
-        # if the buffer has capacity, add the experiences to the end of the buffer
-        if idx_end <= self._cap:
-            self._add(state, action, reward, cost, next_state, done, idx_start, idx_end)
-        
-        # if the buffer does not have capacity, add the experiences to the end of the buffer and wrap around
-        else:
-            k = self._cap - idx_start
-            idx_end = n - k
-            self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap)
-            self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end)
-        
-        # update the buffer size and pointer 
-        self._ptr = idx_end
-        if self._size < self._cap:
-            self._size = min(self._cap, self._size + n)
-
+        self._add(state, action, reward, cost, next_state, done)
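+        # ring-buffer bookkeeping: advance the write pointer, grow the size until
+        # capacity is reached, then wrap so the oldest experiences are overwritten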
+        self._ptr += 1
+        self._size = max(self._size, self._ptr)
+        self._ptr = self._ptr % self._cap
 
     def sample(self, n):
         """
@@ -59,7 +42,6 @@ class ReplayBuffer():
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
             self._costs[idxs], self._next_states[idxs], self._dones[idxs]
     
-
     def clear(self):
         """
         Clears the buffer.
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
index 6f10681932d821e82505abbafb11d6896a4d88d4..8ed70543be081f0330cc7138af70dca3e5b23ab2 100644
--- a/src/cql_sac/agent.py
+++ b/src/cql_sac/agent.py
@@ -4,9 +4,7 @@ import torch.nn.functional as F
 import torch.nn as nn
 from torch.nn.utils import clip_grad_norm_
 from .networks import Critic, Actor
-import numpy as np
 import math
-import copy
 
 
 class CSCCQLSAC(nn.Module):
@@ -27,8 +25,8 @@ class CSCCQLSAC(nn.Module):
         super(CSCCQLSAC, self).__init__()
         self.stats = stats
 
-        state_size = env.observation_space.shape[-1]
-        action_size = env.action_space.shape[-1]
+        state_size = env.observation_space.shape[0]
+        action_size = env.action_space.shape[0]
         hidden_size = args.hidden_size
         self.action_size = action_size
 
@@ -62,7 +60,7 @@ class CSCCQLSAC(nn.Module):
         self.csc_avg_unsafe = args.csc_chi
 
         self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device)
-        self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) 
+        self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate)
         
         # Actor Network 
         self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device)
@@ -79,7 +77,7 @@ class CSCCQLSAC(nn.Module):
         self.safety_critic2_target.load_state_dict(self.safety_critic2.state_dict())
 
         self.safety_critic1_optimizer = optim.Adam(self.safety_critic1.parameters(), lr=self.learning_rate)
-        self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate) 
+        self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate)
         
         # Critic Network (w/ Target Network)
         self.critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
@@ -92,46 +90,43 @@ class CSCCQLSAC(nn.Module):
         self.critic2_target.load_state_dict(self.critic2.state_dict())
 
         self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate)
-        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate) 
+        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate)
 
     def get_action(self, state, shielded=True):
         """ 
         Returns shielded actions for given state as per current policy. 
         """
-        state = torch.from_numpy(state).float().to(self.device)
+        state = torch.from_numpy(state).float().to(self.device).unsqueeze(0)
         
         if shielded:
-            # Repeat state, resulting shape: (shield_iterations, batch_size, state_size)
-            state = state.repeat((self.csc_shield_iterations, 1)).reshape(self.csc_shield_iterations, *state.shape)
+            # Repeat state, resulting shape: (shield_iterations, state_size)
+            state = state.repeat((self.csc_shield_iterations, 1))
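+            # per-action unsafety budget: shrinks as the tracked average of unsafe
+            # episodes (csc_avg_unsafe) approaches the chi target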
             unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
             
-            # Sample all 'csc_shield_iterations' actions at once for every state
+            # Sample all 'csc_shield_iterations' actions at once
             with torch.no_grad():
                 action = self.actor_local.get_action(state).to(self.device)
             
             # Estimate unsafety of all actions
             q1 = self.safety_critic1(state, action)
             q2 = self.safety_critic2(state, action)
-            unsafety = torch.min(q1, q2).squeeze(2)
+            unsafety = torch.min(q1, q2).squeeze(1)
 
-            # Check for actions that qualify (unsafety <= threshold), locate first (if exists) for every state
+            # Check for actions that qualify (unsafety <= threshold), locate first (if exists)
             mask_safe = unsafety <= unsafety_threshold
-            idx_first_safe = mask_safe.int().argmax(dim=0)
 
-            # If all actions are unsafe (> threshold), argmax will return the first unsafe as mask_safe[:,,] == 0 everywhere
-            mask_all_unsafe = (~mask_safe[idx_first_safe, torch.arange(0, mask_safe.shape[1])])
+            # Search for first safe action, check if one exists
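+            # (argmax over an all-False mask returns index 0, hence the explicit is_safe check)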
+            idx_first_safe = mask_safe.int().argmax(dim=0).item()
+            is_safe = mask_safe[idx_first_safe].item()
 
-            # We now build an idx to access the action tensor as follows:
-            # If there was at least one safe action, idx_first_safe will be the first safe action for each state
-            # If there was no safe action, we retrieve the action with minimum unsafety for each state
-            idx_0 = idx_first_safe
-            idx_0[mask_all_unsafe] = unsafety[:, mask_all_unsafe].argmin(dim=0)
-            idx_1 = torch.arange(0, mask_safe.shape[1])
-
-            # Access action tensor and return
-            return action[idx_0, idx_1, :]
+            # Return first safe or alternatively the safest one
+            if is_safe:
+                return action[idx_first_safe, :]
+            else:
+                idx_best = unsafety.argmin(dim=0).item()
+                return action[idx_best, :]
         else:
-            return self.actor_local.get_action(state).to(self.device)
+            return self.actor_local.get_action(state).squeeze(0).to(self.device)
 
     def calc_policy_loss(self, states, alpha):
         actions_pred, log_pis = self.actor_local.evaluate(states)
@@ -348,9 +343,9 @@ class CSCCQLSAC(nn.Module):
             "cql_alpha_loss": cql_alpha_loss.item(),
             "cql_alpha": cql_alpha.item()
         }
-        if self.stats.total_updates % 8 == 0:
-            self.stats.log_update_tensorboard(data)
         self.stats.total_updates += 1
+        if (self.stats.total_updates - 1) % 8 == 0:
+            self.stats.log_update_tensorboard(data)
         return data
 
     def soft_update(self, local_model , target_model):
diff --git a/src/environment.py b/src/environment.py
index dfb89717d076740298985b5eccfa8703497562c7..87d5411f6d89468b2e724684ce357890d8cf10ee 100644
--- a/src/environment.py
+++ b/src/environment.py
@@ -1,29 +1,64 @@
-import safety_gymnasium
-import gymnasium
-import numpy as np
-
-class Gymnasium2SafetyGymnasium(gymnasium.Wrapper):
-    def step(self, action):
-        state, reward, terminated, truncated, info = super().step(action)
-        if 'cost' in info:
-            cost = info['cost']
-        elif isinstance(reward, (int, float)):
-            cost = 0
-        elif isinstance(reward, np.ndarray):
-            cost = np.zeros_like(reward)
-        elif isinstance(reward, list):
-            cost = [0]*len(reward)
-        else:
-            raise NotImplementedError("reward type not recognized") # for now
-        return state, reward, cost, terminated, truncated, info
-
-def create_environment(args):
-    if args.env_id.startswith("Safety"):
-        env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
-    if args.env_id.startswith("Gymnasium_"):
-        id = args.env_id[len('Gymnasium_'):]
-        env = gymnasium.make_vec(id, num_envs=args.num_vectorized_envs, vectorization_mode="sync")
-        env = Gymnasium2SafetyGymnasium(env)
-    if args.env_id.startswith("RaceTrack"):
-        raise NotImplementedError("RaceTrack environment is not implemented yet.")
-    return env    
\ No newline at end of file
+def _parse_env_args_as_dict(env_args):
+    """
+    We assume the env args are represented as a dictionary string, i.e.
+    {name1:value1,name2:value2,...}
+    """
+    import ast
+
+    if not env_args:
+        env_args = '{}'
+    d = ast.literal_eval(env_args)
+    if isinstance(d, dict):
+        return d
+    raise ValueError("Expected a dictionary string as env_args!")
+
+
+def _create_env_racetrack(args):
+    import safety_gymnasium
+    from racetrackgym.argument_parser import Racetrack_parser
+    from racetrackgym.wrapper import SafetyGymnasiumRaceTrackEnv
+    
+    # NOTE: racetrackgym has a custom parser, use it instead
+    rt_args = Racetrack_parser().parse(args.env_args.split())
+    args.env_args = rt_args.__dict__
+
+    safety_gymnasium.register(id=args.env_id, entry_point=SafetyGymnasiumRaceTrackEnv)
+    env = safety_gymnasium.make(id=args.env_id, rt_args=rt_args)
+    return env
+
+
+def _create_env_safetygym(args):
+    import safety_gymnasium
+    from safety_gymnasium.bases.base_task import LidarConf
+    
+    # NOTE: we don't have a safety gymnasium argument parser, hence expect arguments as a dictionary string
+    sg_args = _parse_env_args_as_dict(args.env_args)
+    args.env_args = sg_args.copy()
+
+    # allow changing the lidar via args
+    if 'lidar' in sg_args:
+        lidar_config = dict()
+        for key, val in sg_args['lidar'].items():
+            lidar_config[key] = val
+        
+        # Wrap init function with new default values
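+        # (explicit kwargs still take precedence; the override only fills in missing keys)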
+        def new_init(self, **kwargs):
+            for key, val in lidar_config.items():
+                if key not in kwargs:
+                    kwargs[key] = val
+            return self.__init_original__(**kwargs)
+        
+        LidarConf.__init_original__ = LidarConf.__init__
+        LidarConf.__init__ = new_init
+        del sg_args['lidar']
+
+    env = safety_gymnasium.make(id=args.env_id)
+    return env
+
+
+def create_env(args):
+    name = args.env_id
+    if name == "RaceTrack":
+        return _create_env_racetrack(args)
+    elif name.startswith("Safety"):
+        return _create_env_safetygym(args)
+    else:
+        raise RuntimeError("Unkonwn environment: " + name)
diff --git a/src/stats.py b/src/stats.py
index 4f263bf73e5ef7fb25774d047297b8fcbfce7bf9..af63fd10948c5c89527f76fbcbb2e6f082eaa3e3 100644
--- a/src/stats.py
+++ b/src/stats.py
@@ -10,51 +10,26 @@ class Statistics:
         self.total_train_unsafe = 0
         self.total_updates = 0
 
-        # Used for calculating average unsafe of previous training episodes
-        self.train_unsafe_history = []
-        self._train_unsafe_avg = 0
-
-        # Used for averaging test results
-        self.test_steps_history = []
-        self.test_reward_history = []
-        self.test_cost_history = []
-        self.test_unsafe_history = []
-
-    @property
-    def train_unsafe_avg(self):
-        if len(self.train_unsafe_history) > 0:
-            self._train_unsafe_avg = sum(self.train_unsafe_history) / len(self.train_unsafe_history)
-            self.train_unsafe_history.clear()
-        return self._train_unsafe_avg
+        # Store avg unsafe episodes for CSC
+        self.train_unsafe_avg = 0
     
-    def log_train_history(self, unsafe: np.ndarray):
-        self.train_unsafe_history += unsafe.tolist()
+    def log_train_tensorboard(self, episode_steps, episode_return, episode_cost):
+        name_prefix = f"train/"
+        self.log_tensorboard(name_prefix + 'steps', episode_steps, self.total_train_steps)
+        self.log_tensorboard(name_prefix + 'returns', episode_return, self.total_train_steps)
+        self.log_tensorboard(name_prefix + 'costs', episode_cost, self.total_train_steps)
     
-    def log_test_history(self, steps: np.ndarray, reward: np.ndarray, cost: np.ndarray, unsafe: np.ndarray):
-        self.test_steps_history += steps.tolist()
-        self.test_reward_history += reward.tolist()
-        self.test_cost_history += cost.tolist()
-        self.test_unsafe_history += unsafe.tolist()
-    
-    def log_test_tensorboard(self, name_prefix, ticks):
-        self.log_tensorboard(name_prefix + '/avg_steps', np.array(self.test_steps_history).mean(), ticks)
-        self.log_tensorboard(name_prefix + '/avg_returns', np.array(self.test_reward_history).mean(), ticks)
-        self.log_tensorboard(name_prefix + '/avg_costs', np.array(self.test_cost_history).mean(), ticks)
-        self.log_tensorboard(name_prefix + '/avg_unsafes', np.array(self.test_unsafe_history).mean(), ticks)
-
-        self.test_steps_history.clear()
-        self.test_reward_history.clear()
-        self.test_cost_history.clear()
-        self.test_unsafe_history.clear()
+    def log_test_tensorboard(self, avg_steps, avg_return, avg_cost, avg_unsafe, shielded=True):
+        name_prefix = f"test/{'shielded' if shielded else 'unshielded'}/"
+        self.log_tensorboard(name_prefix + 'avg_steps', avg_steps, self.total_train_steps)
+        self.log_tensorboard(name_prefix + 'avg_return', avg_return, self.total_train_steps)
+        self.log_tensorboard(name_prefix + 'avg_cost', avg_cost, self.total_train_steps)
+        self.log_tensorboard(name_prefix + 'avg_unsafe', avg_unsafe, self.total_train_steps)
     
     def log_update_tensorboard(self, data:dict):
         for k, v in data.items():
             name = f"update/{k}"
             self.writer.add_scalar(name, v, self.total_updates)
 
-    def log_tensorboard(self, name: str, values: np.ndarray, ticks: np.ndarray):
-        if values.size == 1:
-            self.writer.add_scalar(name, values.item(), ticks.item())
-        else:
-            for v, t in zip(values, ticks):
-                self.writer.add_scalar(name, v, t)
\ No newline at end of file
+    def log_tensorboard(self, name, value, step):
+        self.writer.add_scalar(name, value, step)
\ No newline at end of file