diff --git a/main.py b/main.py
index 96a8b0d8ced1a42f8cdbe6cd0e2293ecfe838e60..8215e83da8493c2b95087fe51331d78bf901b928 100644
--- a/main.py
+++ b/main.py
@@ -1,198 +1,315 @@
 import argparse
-import safety_gymnasium
 import numpy as np
 import torch
 import random
 import os
 import json
 import datetime
-
 from torch.utils.tensorboard import SummaryWriter
+
+from src.stats import Statistics
 from src.buffer import ReplayBuffer
-from src.policy import CSCAgent
+from src.cql_sac.agent import CSCCQLSAC
+from src.environment import create_environment
 
 ##################
 # ARGPARSER
 ##################
+def cmd_args():
+    parser = argparse.ArgumentParser(formatter_class=lambda prog:argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))
+    # environment args
+    env_args = parser.add_argument_group('Environment')
+    env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
+                        help="Set the environment")
+    env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
+                        help="Set a cost limit/budget")
+    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
+                        help="Sets the number of vectorized environments")
+
+    # train and test args
+    train_test_args = parser.add_argument_group('Train and Test')
+    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
+                    help="Total number of steps until training is finished")
+    train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
+                        help="Number of episodes until policy optimization")
+    train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
+                        help="Perform evaluation after N * total_train_episodes episodes of training")
+    train_test_args.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
+                        help="Number of updates performed after each training step")
+    train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
+                        help="Number of episodes used for testing")
+    
+    # update args
+    update_args = parser.add_argument_group('Update')
+    update_args.add_argument("--batch_size", action="store", type=int, default=256, metavar="N",
+                        help="Batch size used for training")
+    update_args.add_argument("--tau", action="store", type=float, default=5e-3, metavar="N",
+                        help="Factor used in soft update of target network")
+    update_args.add_argument("--gamma", action="store", type=float, default=0.95, metavar="N",
+                        help="Discount factor for rewards")
+    update_args.add_argument("--learning_rate", action="store", type=float, default=3e-4, metavar="N",
+                        help="Learn rate for the policy and Q networks")
+    
+    # buffer args
+    buffer_args = parser.add_argument_group('Buffer')
+    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
+                        help="Define the maximum capacity of the replay buffer")
+    buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
+                        help="Clear Replay Buffer after every optimization step")
+    
+    # network args
+    network_args = parser.add_argument_group('Networks')
+    network_args.add_argument("--hidden_size", action="store", type=int, default=256, metavar="N",
+                        help="Hidden size of the networks")
+    
+    # cql args
+    cql_args = parser.add_argument_group('CQL')
+    cql_args.add_argument("--cql_with_lagrange", action="store_true", default=False,
+                        help="")
+    cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
+                        help="")
+    cql_args.add_argument("--cql_weight", action="store", type=float, default=1.0, metavar="N",
+                        help="")
+    cql_args.add_argument("--cql_target_action_gap", action="store", type=float, default=10, metavar="N",
+                        help="")
+    
+    # csc args
+    csc_args = parser.add_argument_group('CSC')
+    csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
+                        help="Set the value of chi")
+    csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
+                        help="Set the value of delta")
+    csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
+                        help="Set the value of beta")
+    csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
+                        help="Set the value of alpha")
+    csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
+                        help="Set the initial value of lambda")
 
-parser = argparse.ArgumentParser()
-# environment args
-parser.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
-                    help="Set the environment (default: SafetyPointGoal1-v0)")
-parser.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
-                    help="Set a cost limit at which point an episode is considered unsafe (default: 25)")
-parser.add_argument("--enforce_cost_limit", action="store_true", default=False,
-                    help="Aborts episode if cost limit is reached (default: False)")
-
-# train args
-parser.add_argument("--train_episodes", action="store", type=int, default=5, metavar="N",
-                    help="Number of episodes until policy optimization (default: 5)")
-parser.add_argument("--train_until_test", action="store", type=int, default=3, metavar="N",
-                    help="Perform evaluation after N * train_episodes episodes of training (default: 3)")
-parser.add_argument("--update_iterations", action="store", type=int, default=1, metavar="N",
-                    help="Number of updates performed after each training step (default: 1)")
-parser.add_argument("--test_episodes", action="store", type=int, default=5, metavar="N",
-                    help="Number of episodes used for testing (default: 5)")
-parser.add_argument("--total_steps", action="store", type=int, default=2_000_000, metavar="N",
-                    help="Total number of steps until training is finished (default: 2_000_000)")
-parser.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
-                    help="Batch size used for training (default: 1024)")
-parser.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
-                    help="Factor used in soft update of target network (default: 0.05)")
-
-# buffer args
-parser.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
-                    help="Define the maximum capacity of the replay buffer (default: 50_000)")
-parser.add_argument("--clear_buffer", action="store_true", default=False,
-                    help="Clear Replay Buffer after every optimization step (default: False)")
-
-# csc args
-parser.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
-                    help="Maximum number of actions sampled during shielding (default: 100)")
-parser.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
-                    help="Maximum number of line search update iterations (default: 20)")
-parser.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
-                    help="Number of samples to estimate expectations (default: 20)")
-parser.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
-                    help="Set the value of chi (default: 0.05)")
-parser.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
-                    help="Set the value of delta (default: 0.01)")
-parser.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N",
-                    help="Set the value of gamma (default: 0.99)")
-parser.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
-                    help="Set the value of beta (default: 0.7)")
-parser.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
-                    help="Set the value of alpha (default: 0.5)")
-parser.add_argument("--csc_lambda", action="store", type=float, default=4e-2, metavar="N",
-                    help="Set the initial value of lambda (default: 4e-2)")
-parser.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
-                    help="Learn rate for the safety critic (default: 2e-4)")
-parser.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
-                    help="Learn rate for the value network (default: 1e-3)")
-parser.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
-                    help="Hidden dimension of the networks (default: 32)")
-
-# common args
-parser.add_argument("--seed", action="store", type=int, default=42, metavar="N",
-                    help="Set a custom seed for the rng (default: 42)")
-parser.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
-                    help="Set the device for pytorch to use (default: cuda)")
-parser.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
-                    help="Set the output and log directory path (default: ./runs)")
-parser.add_argument("--num_threads", action="store", type=int, default=32, metavar="N",
-                    help="Set the maximum number of threads for pytorch (default: 32)")
-args = parser.parse_args()
+    # common args
+    common_args = parser.add_argument_group('Common')
+    common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N",
+                        help="Set a custom seed for the rng")
+    common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
+                        help="Set the device for pytorch to use")
+    common_args.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
+                        help="Set the output and log directory path")
+    common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
+                        help="Set the maximum number of threads for pytorch and numpy")
+
+    args = parser.parse_args()
+    return args
 
 ##################
 # SETUP
 ##################
 
-torch.set_num_threads(args.num_threads)
-
-random.seed(args.seed)
-np.random.seed(args.seed)
-torch.manual_seed(args.seed)
-torch.set_default_dtype(torch.float64)
-torch.set_default_device(args.device)
+def setup(args):
+    """
+    Performs setup like fixing seeds, initializing env and agent, buffer and stats.
+    """
+    torch.set_num_threads(args.num_threads)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
 
-output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
-writer = SummaryWriter(log_dir=output_dir)
-with open(os.path.join(output_dir, "config.json"), "w") as file:
-    json.dump(args.__dict__, file, indent=2)
+    output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
+    writer = SummaryWriter(log_dir=output_dir)
+    with open(os.path.join(output_dir, "config.json"), "w") as file:
+        json.dump(args.__dict__, file, indent=2)
 
-env = safety_gymnasium.make(id=args.env_id, autoreset=False)
-buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
-agent = CSCAgent(env, args, writer)
+    env = create_environment(args=args)
+    buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
+    stats = Statistics(writer=writer)
+    agent = CSCCQLSAC(env=env, args=args, stats=stats)
 
-total_episodes = 0
-total_failures = 0
-total_steps = 0
+    return env, agent, buffer, stats
 
 ##################
-# USEFUL FUNCTIONS
+# EXPLORATION
 ##################
 
 @torch.no_grad
-def run_episode_single(train, shielded):
-    episode_steps = 0
-    episode_reward = 0.
-    episode_cost = 0.
+def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train=True, shielded=True):
+    # track currently running and leftover episodes
+    open_episodes = args.train_episodes if train else args.test_episodes
+    running_episodes = args.num_vectorized_envs
+
+    # initialize mask and stats per episode
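+    # mask[i] == True means sub-environment i still contributes to this run;
+    # masked-out sub-environments keep being stepped (the vector env steps all of
+    # them), but their transitions and statistics are ignored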
+    mask = np.ones(args.num_vectorized_envs, dtype=np.bool_)
+    episode_steps = np.zeros_like(mask, dtype=np.uint64)
+    episode_reward = np.zeros_like(mask, dtype=np.float64)
+    episode_cost = np.zeros_like(mask, dtype=np.float64)
+
+    # adjust mask in case we have fewer runs than environments
+    if open_episodes < args.num_vectorized_envs:
+        mask[open_episodes:] = False
+        running_episodes = open_episodes
+    open_episodes -= running_episodes
+
+    state, info = env.reset(seed=random.randint(0, 2**31-1))
     
-    state, info = env.reset()
+    while running_episodes > 0:
+        # sample and execute actions
+        actions = agent.get_action(state, eval=not train).cpu().numpy()
+        next_state, reward, cost, terminated, truncated, info = env.step(actions)
+        done = terminated | truncated
+        not_done_masked = ((~done) & mask)
+        done_masked = done & mask
+        done_masked_count = done_masked.sum()
+
+        # increment stats
+        episode_steps[mask] += 1
+        episode_reward[mask] += reward[mask]
+        episode_cost[mask] += cost[mask]
+
+        # if any run has finished, we need to take special care
+        # 1. train: extract final_observation from info dict (single envs autoreset, no manual reset needed) and add to buffer
+        # 2. train+test: log episode stats using Statistics class
+        # 3. train+test: reset episode stats
+        # 4. train+test: adjust mask (if necessary)
+        if done_masked_count > 0:
+            if train:
+                # add experiences to buffer
+                buffer.add(
+                    state[not_done_masked], 
+                    actions[not_done_masked], 
+                    reward[not_done_masked], 
+                    cost[not_done_masked], 
+                    next_state[not_done_masked],
+                    terminated[not_done_masked]
+                )
+                buffer.add(
+                    state[done_masked], 
+                    actions[done_masked], 
+                    reward[done_masked], 
+                    cost[done_masked], 
+                    np.stack(info['final_observation'], axis=0)[done_masked],
+                    terminated[done_masked]
+                )
+
+                # record finished episodes
+                ticks = stats.total_train_steps + np.cumsum(episode_steps[done_masked], axis=0)
+                stats.log_tensorboard("train/returns", episode_reward[done_masked], ticks)
+                stats.log_tensorboard("train/costs", episode_cost[done_masked], ticks)
+                stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks)
+                stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks)
+
+                stats.log_train_history((episode_cost[done_masked] > args.cost_limit).astype(np.uint8))
+
+                stats.total_train_episodes += done_masked_count
+                stats.total_train_steps += episode_steps[done_masked].sum()
+                stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
+            
+            else:
+                # record finished episodes; the stats module averages over all
+                # finished test episodes when they are later logged to tensorboard
+                stats.log_test_history(
+                    steps=episode_steps[done_masked],
+                    reward=episode_reward[done_masked],
+                    cost=episode_cost[done_masked],
+                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
+                )
+            
+            # reset episode stats
+            state = next_state
+            episode_steps[done_masked] = 0
+            episode_reward[done_masked] = 0
+            episode_cost[done_masked] = 0
+
+            # adjust mask, running and open episodes counter
+            if open_episodes < done_masked_count:  # fewer left than just finished
+                done_masked_idxs = done_masked.nonzero()[0]
+                mask[done_masked_idxs[open_episodes:]] = False
+                running_episodes -= (done_masked_count - open_episodes)
+                open_episodes = 0
+            else:   # at least as many left as just finished
+                open_episodes -= done_masked_count
+
+        # no run has finished, just record experiences (if training)
+        else:
+            if train:
+                buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask])
+            state = next_state
+
+    # average and log finished test episodes
+    if not train:
+        stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps)
+
+def single_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
+    # NOTE: Unused function
+    state, info = env.reset(seed=random.randint(0, 2**31-1))
+    episode_reward = 0
+    episode_cost = 0
+    episode_steps = 0
     done = False
 
     while not done:
-        action = agent.sample(state, shielded=shielded)
+        with torch.no_grad():
+            action = agent.get_action(np.expand_dims(state, 0), eval=not train).cpu().numpy().squeeze(0)
         next_state, reward, cost, terminated, truncated, info = env.step(action)
         done = terminated or truncated
-
-        if train:
-            buffer.add(state, action, reward, cost, next_state)
-        state = next_state
-
-        episode_steps += 1
         episode_reward += reward
         episode_cost += cost
+        episode_steps += 1
 
-        if train and args.enforce_cost_limit:   # we dont care about the cost limit while testing
-            done = done or (episode_cost >= args.cost_limit)
-    
-    return episode_steps, episode_reward, episode_cost
-
-@torch.no_grad
-def run_episode_multiple(num_episodes, train, shielded):
-    global total_episodes, total_failures, total_steps
-    avg_steps = 0
-    avg_reward = 0
-    avg_cost = 0
-    avg_failures = 0
-
-    for _ in range(num_episodes):
-        episode_steps, episode_reward, episode_cost = run_episode_single(train=train, shielded=shielded)
-        episode_failure = int(episode_cost >= args.cost_limit)
+        state = np.expand_dims(state, 0)
+        next_state = np.expand_dims(next_state, 0)
+        action = np.expand_dims(action, 0)
+        reward = np.array([reward])
+        cost = np.array([cost])
 
         if train:
-            total_episodes += 1
-            total_failures += episode_failure
-            total_steps += env.spec.max_episode_steps
+            buffer.add(state, action, reward, cost, next_state, np.array([terminated]))
+            if stats.total_train_steps >= 10_000:
+                # agent.learn returns a dict of loss/statistics values and handles
+                # tensorboard logging and the update counter itself
+                data = agent.learn(experiences=buffer.sample(n=args.batch_size))
+                for key, value in data.items():
+                    stats.writer.add_scalar(f"debug/{key}", value, stats.total_updates)
 
-            writer.add_scalar("train/episode_reward", episode_reward, total_episodes)
-            writer.add_scalar("train/episode_cost", episode_cost, total_episodes)
-            writer.add_scalar("train/episode_failure", episode_failure, total_episodes)
+        state = next_state.squeeze(0)
         
+    stats.writer.add_scalar("train/returns", episode_reward, stats.total_train_steps)
+    stats.writer.add_scalar("train/costs", episode_cost, stats.total_train_steps)
+    stats.writer.add_scalar("train/steps", episode_steps, stats.total_train_steps)
 
-        avg_steps += episode_steps
-        avg_reward += episode_reward
-        avg_cost += episode_cost
-        avg_failures += episode_failure
-    
-    avg_steps /= num_episodes
-    avg_reward /= num_episodes
-    avg_cost /= num_episodes
-    avg_failures /= num_episodes
-
-    return avg_steps, avg_reward, avg_cost, avg_failures
+    stats.total_train_episodes += 1
+    stats.total_train_steps += episode_steps
+    stats.total_train_unsafe += int(episode_cost > args.cost_limit)
 
 ##################
 # MAIN LOOP
 ##################
 
-finished = False
-while not finished:
-    for _ in range(args.train_until_test):
-        if total_steps >= args.total_steps:
-            finished = True
-            break
-        avg_steps, avg_reward, avg_cost, avg_failures = run_episode_multiple(num_episodes=args.train_episodes, train=True, shielded=True)
-        for i in range(args.update_iterations):
-            agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=total_episodes)
-        if args.clear_buffer:
-            buffer.clear()
-
-    for shielded, postfix in zip([True, False], ["shielded", "unshielded"]):
-        avg_steps, avg_reward, avg_cost, avg_failures = run_episode_multiple(num_episodes=args.test_episodes, train=False, shielded=shielded)
-        writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, total_episodes)
-        writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, total_episodes)
-        writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, total_episodes)
-
-writer.flush()
\ No newline at end of file
+def main(args, env, agent, buffer, stats):
+    finished = False
+    while not finished:
+        # Training + Update Loop
+        for _ in range(args.train_until_test):
+            # stop once the configured number of training steps has been collected
+            if stats.total_train_steps >= args.total_train_steps:
+                finished = True
+                break
+
+            # 1. Run exploration for training
+            run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
+
+            # 2. Perform updates
+            for _ in range(args.update_iterations):
+                agent.learn(experiences=buffer.sample(n=args.batch_size))
+
+            # 3. After update stuff
+            if args.clear_buffer:
+                buffer.clear()
+
+        # Test loop (currently shielded only; the unshielded evaluation is disabled)
+        # for shielded in [True, False]:
+        run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=True)
+
+if __name__ == '__main__':
+    args = cmd_args()
+    main(args, *setup(args))
\ No newline at end of file
diff --git a/src/buffer.py b/src/buffer.py
index b74190d5e20ee224520cc79f373c171e133a71d4..6b72adbb5bc8264b4db5d9144b4f9bfabb3e686d 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -1,33 +1,68 @@
-import numpy
-import gymnasium
+import numpy as np
 
 class ReplayBuffer():
-    def __init__(self, env:gymnasium.Env, cap):
-        self._cap = cap
-        self._size = 0
-        self._ptr = 0
+    """
+    Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. Handles batched experiences.
+    """
+    def __init__(self, env, cap):
+        self._cap = max(1, cap)
+        self._size = 0              # number of experiences in the buffer
+        self._ptr = 0               # pointer to the next available slot in the buffer
+
+        self._states = np.zeros((self._cap, env.observation_space.shape[-1]), dtype=np.float64)
+        self._actions = np.zeros((self._cap, env.action_space.shape[-1]), dtype=np.float64)
+        self._rewards = np.zeros((self._cap, ), dtype=np.float64)
+        self._costs = np.zeros((self._cap, ), dtype=np.float64)
+        self._next_states = np.zeros_like(self._states)
+        self._dones = np.zeros((self._cap, ), dtype=np.uint8)
+
+
+    def _add(self, state, action, reward, cost, next_state, done, start, end):
+        self._states[start:end] = state
+        self._actions[start:end] = action
+        self._rewards[start:end] = reward
+        self._costs[start:end] = cost
+        self._next_states[start:end] = next_state
+        self._dones[start:end] = done
+
+
+    def add(self, state, action, reward, cost, next_state, done):
+        """
+        Adds experiences to the buffer. Assumes batched experiences.
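+
+        Example (ring-buffer wrap): with cap=4, _ptr=3 and a batch of n=3, one
+        experience is written to slot 3 and the remaining two wrap into slots 0-1,
+        leaving _ptr=2.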
+        """
+        n = state.shape[0]          # NOTE: n should be less than or equal to the buffer capacity
+        idx_start = self._ptr
+        idx_end = self._ptr + n
+
+        # if the buffer has capacity, add the experiences to the end of the buffer
+        if idx_end <= self._cap:
+            self._add(state, action, reward, cost, next_state, done, idx_start, idx_end)
+        
+        # if the buffer does not have capacity, add the experiences to the end of the buffer and wrap around
+        else:
+            k = self._cap - idx_start
+            idx_end = n - k
+            self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap)
+            self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end)
+        
+        # update the pointer (wrapping at capacity) and the current size
+        self._ptr = idx_end % self._cap
+        self._size = min(self._cap, self._size + n)
 
-        self._states = numpy.zeros((cap, env.observation_space.shape[0]), dtype=numpy.float64)
-        self._actions = numpy.zeros((cap, env.action_space.shape[0]), dtype=numpy.float64)
-        self._rewards = numpy.zeros((cap, ), dtype=numpy.float64)
-        self._costs = numpy.zeros((cap, ), dtype=numpy.float64)
-        self._next_states = numpy.zeros_like(self._states)
-    
 
-    def add(self, state, action, reward, cost, next_state):
-        self._states[self._ptr] = state
-        self._actions[self._ptr] = action
-        self._rewards[self._ptr] = reward
-        self._costs[self._ptr] = cost
-        self._next_states[self._ptr] = next_state
-        self._ptr = (self._ptr+1) % self._cap
-        self._size = min(self._size+1, self._cap)
-    
     def sample(self, n):
-        idxs = numpy.random.randint(low=0, high=self._size, size=n)
+        """
+        Samples n experiences from the buffer.
+        """
+        idxs = np.random.randint(low=0, high=self._size, size=n)
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
-            self._costs[idxs], self._next_states[idxs]
+            self._costs[idxs], self._next_states[idxs], self._dones[idxs]
     
+
     def clear(self):
+        """
+        Clears the buffer.
+        """
         self._size = 0
         self._ptr = 0
\ No newline at end of file
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6e4cb0cef61d21173db14de1c7c44b4b89f447
--- /dev/null
+++ b/src/cql_sac/agent.py
@@ -0,0 +1,363 @@
+import torch
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn.utils import clip_grad_norm_
+from .networks import Critic, Actor
+import numpy as np
+import math
+import copy
+
+
+class CSCCQLSAC(nn.Module):
+    """Interacts with and learns from the environment."""
+    
+    def __init__(self,
+                        env,
+                        args,
+                        stats
+                ):
+        """Initialize an Agent object.
+        
+        Params
+        ======
+            env : the vector environment
+            args : the argparse arguments
+            stats : the Statistics object used for logging
+        """
+        super(CSCCQLSAC, self).__init__()
+        self.stats = stats
+
+        state_size = env.observation_space.shape[-1]
+        action_size = env.action_space.shape[-1]
+        hidden_size = args.hidden_size
+        self.action_size = action_size
+
+        self.device = args.device
+        
+        self.learning_rate = args.learning_rate
+        self.gamma = args.gamma
+        self.tau = args.tau
+        self.clip_grad_param = 1
+
+        self.target_entropy = -action_size  # -dim(A)
+
+        self.log_alpha = torch.tensor([0.0], requires_grad=True, device=self.device)
+        self.alpha = self.log_alpha.exp().detach()
+        self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=self.learning_rate) 
+        
+        # CQL params
+        self.cql_with_lagrange = args.cql_with_lagrange
+        self.cql_temp = args.cql_temp
+        self.cql_weight = args.cql_weight
+        self.cql_target_action_gap = args.cql_target_action_gap
+        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
+        self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate) 
+
+        # CSC params
+        self.csc_shield_iterations = 100
+        self.csc_alpha = args.csc_alpha
+        self.csc_beta = args.csc_beta
+        self.csc_delta = args.csc_delta
+        self.csc_chi = args.csc_chi
+        self.csc_avg_unsafe = args.csc_chi
+
+        self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device)
+        self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) 
+        
+        # Actor Network 
+        self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device)
+        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.learning_rate)
+
+        # Safety Critic Network (w/ Target Network)
+        self.safety_critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic2 = Critic(state_size, action_size, hidden_size).to(self.device)
+        
+        self.safety_critic1_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic1_target.load_state_dict(self.safety_critic1.state_dict())
+
+        self.safety_critic2_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic2_target.load_state_dict(self.safety_critic2.state_dict())
+
+        self.safety_critic1_optimizer = optim.Adam(self.safety_critic1.parameters(), lr=self.learning_rate)
+        self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate) 
+        
+        # Critic Network (w/ Target Network)
+        self.critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.critic2 = Critic(state_size, action_size, hidden_size).to(self.device)
+        
+        self.critic1_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.critic1_target.load_state_dict(self.critic1.state_dict())
+
+        self.critic2_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.critic2_target.load_state_dict(self.critic2.state_dict())
+
+        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate)
+        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate) 
+
+    
+    def get_action(self, state, eval=False):
+        """ 
+        Returns shielded actions for given state as per current policy. 
+        
+        Note: eval is currently ignored.
+        """
+        state = torch.from_numpy(state).float().to(self.device)
+
+        batch_size = state.shape[0]
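+        # Shielding rule as implemented below: an action counts as safe when the
+        # safety critics' estimate of its discounted failure probability is at most
+        # (1 - gamma) * (chi - avg_unsafe), where avg_unsafe is the running rate of
+        # unsafe training episodes tracked by the Statistics object.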
+        unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
+        unsafety_best = torch.full((batch_size, ), fill_value=unsafety_threshold+1).to(self.device)
+        action_best = torch.zeros(batch_size, self.action_size).to(self.device)
+
+        # Run at max 'csc_shield_iterations' iterations to find safe action
+        for _ in range(self.csc_shield_iterations):
+            # If all actions are already safe, break
+            mask_safe = unsafety_best <= unsafety_threshold
+            if mask_safe.all(): break
+
+            # Sample new actions
+            with torch.no_grad():
+                action = self.actor_local.get_action(state).to(self.device)
+            
+            # Estimate safety of new actions
+            q1 = self.safety_critic1(state, action)
+            q2 = self.safety_critic2(state, action)
+            unsafety = torch.min(q1, q2).squeeze(1)
+
+            # Update best actions if they are still unsafe and new actions are safer
+            mask_update = (~mask_safe) & (unsafety < unsafety_best)
+            unsafety_best[mask_update] = unsafety[mask_update]
+            action_best[mask_update] = action[mask_update]
+
+        return action_best
+
+    def calc_policy_loss(self, states, alpha):
+        actions_pred, log_pis = self.actor_local.evaluate(states)
+
+        q1 = self.critic1(states, actions_pred.squeeze(0))
+        q2 = self.critic2(states, actions_pred.squeeze(0))
+        min_Q = torch.min(q1,q2)
+        actor_loss = ((alpha * log_pis - min_Q )).mean()
+        return actor_loss, log_pis
+
+    def _compute_policy_values(self, obs_pi, obs_q):
+        #with torch.no_grad():
+        actions_pred, log_pis = self.actor_local.evaluate(obs_pi)
+        qs1 = self.safety_critic1(obs_q, actions_pred)
+        qs2 = self.safety_critic2(obs_q, actions_pred)
+        return qs1 - log_pis.detach(), qs2 - log_pis.detach()
+    
+    def _compute_random_values(self, obs, actions, critic):
+        random_values = critic(obs, actions)
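+        # log(0.5 ** action_size) is the log-density of the uniform proposal over
+        # [-1, 1]^action_size; subtracting it importance-weights the random actions
+        # inside the CQL logsumexp term.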
+        random_log_probs = math.log(0.5 ** self.action_size)
+        return random_values - random_log_probs
+    
+    def learn(self, experiences):
+        """Updates actor, critics and entropy_alpha parameters using given batch of experience tuples.
+        Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state))
+        Critic_loss = MSE(Q, Q_target)
+        Actor_loss = α * log_pi(a|s) - Q(s,a)
+        where:
+            actor_target(state) -> action
+            critic_target(state, action) -> Q-value
+        Params
+        ======
+            experiences (Tuple[np.ndarray]): batch of (s, a, r, c, s', done) arrays
+                sampled from the replay buffer
+        """
+        self.csc_avg_unsafe = self.stats.train_unsafe_avg
+        states, actions, rewards, costs, next_states, dones = experiences
+
+        states = torch.from_numpy(states).float().to(self.device)
+        actions = torch.from_numpy(actions).float().to(self.device)
+        rewards = torch.from_numpy(rewards).float().to(self.device).view(-1, 1)
+        costs = torch.from_numpy(costs).float().to(self.device).view(-1, 1)
+        next_states = torch.from_numpy(next_states).float().to(self.device)
+        dones = torch.from_numpy(dones).float().to(self.device).view(-1, 1)
+
+        # ---------------------------- update critic ---------------------------- #
+        # Get predicted next-state actions and Q values from target models
+        with torch.no_grad():
+            next_action, new_log_pi = self.actor_local.evaluate(next_states)
+            Q_target1_next = self.critic1_target(next_states, next_action)
+            Q_target2_next = self.critic2_target(next_states, next_action)
+            Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha * new_log_pi
+            # Compute Q targets for current states (y_i)
+            Q_targets = rewards + (self.gamma * (1 - dones) * Q_target_next)
+
+        # Compute critic loss
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
+
+        critic1_loss = F.mse_loss(q1, Q_targets)
+        critic2_loss = F.mse_loss(q2, Q_targets)
+
+        # Update critics
+        # critic 1
+        self.critic1_optimizer.zero_grad()
+        critic1_loss.backward(retain_graph=True)
+        clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param)
+        self.critic1_optimizer.step()
+        # critic 2
+        self.critic2_optimizer.zero_grad()
+        critic2_loss.backward()
+        clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param)
+        self.critic2_optimizer.step()
+
+        # ---------------------------- update safety critic ---------------------------- #
+        # Get predicted next-state actions and Q values from target models
+        with torch.no_grad():
+            next_action, new_log_pi = self.actor_local.evaluate(next_states)
+            Q_target1_next = self.safety_critic1_target(next_states, next_action)
+            Q_target2_next = self.safety_critic2_target(next_states, next_action)
+            Q_target_next = torch.min(Q_target1_next, Q_target2_next) # - self.alpha * new_log_pi
+            # Compute Q targets for current states (y_i)
+            Q_targets = costs + (self.gamma * (1 - dones) * Q_target_next)
+
+        # Compute safety_critic loss
+        q1 = self.safety_critic1(states, actions)
+        q2 = self.safety_critic2(states, actions)
+
+        safety_critic1_loss = F.mse_loss(q1, Q_targets)
+        safety_critic2_loss = F.mse_loss(q2, Q_targets)
+        
+        # CQL addon
+        num_repeat = 10
+        random_actions = torch.FloatTensor(q1.shape[0] * num_repeat, actions.shape[-1]).uniform_(-1, 1).to(self.device)
+        temp_states = states.unsqueeze(1).repeat(1, num_repeat, 1).view(states.shape[0] * num_repeat, states.shape[1])
+        temp_next_states = next_states.unsqueeze(1).repeat(1, num_repeat, 1).view(next_states.shape[0] * num_repeat, next_states.shape[1])
+        
+        current_pi_values1, current_pi_values2  = self._compute_policy_values(temp_states, temp_states)
+        next_pi_values1, next_pi_values2 = self._compute_policy_values(temp_next_states, temp_states)
+        
+        random_values1 = self._compute_random_values(temp_states, random_actions, self.safety_critic1).reshape(states.shape[0], num_repeat, 1)
+        random_values2 = self._compute_random_values(temp_states, random_actions, self.safety_critic2).reshape(states.shape[0], num_repeat, 1)
+        
+        current_pi_values1 = current_pi_values1.reshape(states.shape[0], num_repeat, 1)
+        current_pi_values2 = current_pi_values2.reshape(states.shape[0], num_repeat, 1)
+
+        next_pi_values1 = next_pi_values1.reshape(states.shape[0], num_repeat, 1)
+        next_pi_values2 = next_pi_values2.reshape(states.shape[0], num_repeat, 1)
+        
+        cat_q1 = torch.cat([random_values1, current_pi_values1, next_pi_values1], 1)
+        cat_q2 = torch.cat([random_values2, current_pi_values2, next_pi_values2], 1)
+        
+        assert cat_q1.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q1 instead has shape: {cat_q1.shape}"
+        assert cat_q2.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q2 instead has shape: {cat_q2.shape}"
+        
+        # flipped sign of cql1_scaled_loss and cql2_scaled_loss
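+        # (compared to standard CQL, which pushes out-of-distribution Q-values down,
+        # the sign is flipped because these critics estimate *cost*: minimizing the
+        # term below pushes OOD cost estimates up, keeping the shield pessimistic
+        # about unseen actions)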
+        cql1_scaled_loss = -(torch.logsumexp(cat_q1 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q1.mean() * self.cql_weight)
+        cql2_scaled_loss = -(torch.logsumexp(cat_q2 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q2.mean() * self.cql_weight)
+        
+        cql_alpha_loss = torch.FloatTensor([0.0])
+        cql_alpha = torch.FloatTensor([1.0])
+        if self.cql_with_lagrange:
+            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), min=0.0, max=1000000.0)
+            cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.cql_target_action_gap)
+            cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.cql_target_action_gap)
+
+            self.cql_alpha_optimizer.zero_grad()
+            cql_alpha_loss = (- cql1_scaled_loss - cql2_scaled_loss) * 0.5 
+            cql_alpha_loss.backward(retain_graph=True)
+            self.cql_alpha_optimizer.step()
+        
+        total_c1_loss = safety_critic1_loss + cql1_scaled_loss
+        total_c2_loss = safety_critic2_loss + cql2_scaled_loss
+        
+        
+        # Update safety_critics
+        # safety_critic 1
+        self.safety_critic1_optimizer.zero_grad()
+        total_c1_loss.backward(retain_graph=True)
+        clip_grad_norm_(self.safety_critic1.parameters(), self.clip_grad_param)
+        self.safety_critic1_optimizer.step()
+        # safety_critic 2
+        self.safety_critic2_optimizer.zero_grad()
+        total_c2_loss.backward()
+        clip_grad_norm_(self.safety_critic2.parameters(), self.clip_grad_param)
+        self.safety_critic2_optimizer.step()
+
+        # ---------------------------- update csc lambda ---------------------------- #
+        # Estimate cost advantage
+        with torch.no_grad():
+            q1 = self.safety_critic1(states, actions)
+            q2 = self.safety_critic2(states, actions)
+            v = torch.min(q1, q2)
+
+            new_action, new_log_pi = self.actor_local.evaluate(states)
+            q1 = self.safety_critic1(states, new_action)
+            q2 = self.safety_critic2(states, new_action)
+            q = torch.min(q1, q2)
+
+            cost_advantage = (q - v).mean()
+
+        # Compute csc lambda loss
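+        # The gradient of this loss w.r.t. lambda is minus the estimated constraint
+        # violation, so gradient descent grows lambda while the estimated failure
+        # rate exceeds chi and shrinks it otherwise (a dual-ascent update).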
+        csc_lambda_loss = -self.csc_lambda*(self.csc_avg_unsafe + (1 / (1 - self.gamma)) * cost_advantage - self.csc_chi)
+
+        self.csc_lambda_optimizer.zero_grad()
+        csc_lambda_loss.backward()
+        self.csc_lambda_optimizer.step()
+
+        # ---------------------------- update actor ---------------------------- #
+        # Estimate reward advantage
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
+        v = torch.min(q1, q2).detach()
+
+        new_action, new_log_pi = self.actor_local.evaluate(states)
+        q1 = self.critic1(states, new_action)
+        q2 = self.critic2(states, new_action)
+        q = torch.min(q1, q2)
+
+        reward_advantage = q - v
+
+        # Optimize actor
+        actor_loss = ((self.alpha * new_log_pi - reward_advantage)).mean()
+        self.actor_optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor_optimizer.step()
+        
+        # Compute alpha loss
+        alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean()
+        self.alpha_optimizer.zero_grad()
+        alpha_loss.backward()
+        self.alpha_optimizer.step()
+        self.alpha = self.log_alpha.exp().detach()
+
+        # ----------------------- update target networks ----------------------- #
+        self.soft_update(self.critic1, self.critic1_target)
+        self.soft_update(self.critic2, self.critic2_target)
+        self.soft_update(self.safety_critic1, self.safety_critic1_target)
+        self.soft_update(self.safety_critic2, self.safety_critic2_target)
+        
+        # ----------------------- update stats ----------------------- #
+        data = {
+            "actor_loss": actor_loss.item(),
+            "alpha_loss": alpha_loss.item(),
+            "alpha": self.alpha.item(),
+            "lambda_loss": csc_lambda_loss.item(),
+            "lambda": self.csc_lambda.item(),
+            "critic1_loss": critic1_loss.item(),
+            "critic2_loss": critic2_loss.item(),
+            "cql1_scaled_loss": cql1_scaled_loss.item(),
+            "cql2_scaled_loss": cql2_scaled_loss.item(),
+            "total_c1_loss": total_c1_loss.item(),
+            "total_c2_loss": total_c2_loss.item(),
+            "cql_alpha_loss": cql_alpha_loss.item(),
+            "cql_alpha": cql_alpha.item()
+        }
+        if self.stats.total_updates % 8 == 0:
+            self.stats.log_update_tensorboard(data)
+        self.stats.total_updates += 1
+        return data
+
+    def soft_update(self, local_model , target_model):
+        """Soft update model parameters.
+        θ_target = τ*θ_local + (1 - τ)*θ_target
+        Params
+        ======
+            local_model: PyTorch model (weights will be copied from)
+            target_model: PyTorch model (weights will be copied to)
+            tau (float): interpolation parameter 
+        """
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)
diff --git a/src/cql_sac/buffer.py b/src/cql_sac/buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf3b751618b42d57e1e38a84762f06138d16eca
--- /dev/null
+++ b/src/cql_sac/buffer.py
@@ -0,0 +1,41 @@
+import numpy as np
+import random
+import torch
+from collections import deque, namedtuple
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, buffer_size, batch_size, device):
+        """Initialize a ReplayBuffer object.
+        Params
+        ======
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+            device (str or torch.device): device that sampled tensors are moved to
+        """
+        self.device = device
+        self.memory = deque(maxlen=buffer_size)  
+        self.batch_size = batch_size
+        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+    
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = self.experience(state, action, reward, next_state, done)
+        self.memory.append(e)
+    
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(np.stack([e.state for e in experiences if e is not None])).float().to(self.device)
+        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(self.device)
+        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device)
+        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(self.device)
+        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(self.device)
+  
+        return (states, actions, rewards, next_states, dones)
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
diff --git a/src/cql_sac/networks.py b/src/cql_sac/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f26a3dc986994f68c3bc7a241e475b24da7e7c
--- /dev/null
+++ b/src/cql_sac/networks.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+from torch.distributions import Normal
+import numpy as np
+import torch.nn.functional as F
+
+
+def hidden_init(layer):
+    fan_in = layer.weight.data.size()[0]
+    lim = 1. / np.sqrt(fan_in)
+    return (-lim, lim)
+
+class Actor(nn.Module):
+    """Actor (Policy) Model."""
+
+    def __init__(self, state_size, action_size, hidden_size=32, log_std_min=-20, log_std_max=2):
+        """Initialize parameters and build model.
+        Params
+        ======
+            state_size (int): Dimension of each state
+            action_size (int): Dimension of each action
+            hidden_size (int): Number of nodes in each hidden layer
+        """
+        super(Actor, self).__init__()
+        self.log_std_min = log_std_min
+        self.log_std_max = log_std_max
+        
+        self.fc1 = nn.Linear(state_size, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, hidden_size)
+
+        self.mu = nn.Linear(hidden_size, action_size)
+        self.log_std_linear = nn.Linear(hidden_size, action_size)
+
+    def forward(self, state):
+        x = F.relu(self.fc1(state))
+        x = F.relu(self.fc2(x))
+        mu = self.mu(x)
+
+        log_std = self.log_std_linear(x)
+        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+        return mu, log_std
+    
+    def evaluate(self, state, epsilon=1e-6):
+        mu, log_std = self.forward(state)
+        std = log_std.exp()
+        dist = Normal(mu, std)
+        e = dist.rsample().to(state.device)
+        action = torch.tanh(e)
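+        # correct dist.log_prob(e) by -log(1 - tanh(e)^2 + eps), the change-of-variables
+        # term for the tanh squashing, so log_prob is the log-density of the squashed action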
+        log_prob = (dist.log_prob(e) - torch.log(1 - action.pow(2) + epsilon)).sum(1, keepdim=True)
+
+        return action, log_prob
+        
+    
+    def get_action(self, state):
+        """
+        returns the action based on a squashed gaussian policy. That means the samples are obtained according to:
+        a(s,e)= tanh(mu(s)+sigma(s)+e)
+        """
+        mu, log_std = self.forward(state)
+        std = log_std.exp()
+        dist = Normal(mu, std)
+        e = dist.rsample().to(state.device)
+        action = torch.tanh(e)
+        return action.detach().cpu()
+    
+    def get_det_action(self, state):
+        mu, log_std = self.forward(state)
+        return torch.tanh(mu).detach().cpu()
+
+
+class Critic(nn.Module):
+    """Critic (Value) Model."""
+
+    def __init__(self, state_size, action_size, hidden_size=32):
+        """Initialize parameters and build model.
+        Params
+        ======
+            state_size (int): Dimension of each state
+            action_size (int): Dimension of each action
+            hidden_size (int): Number of nodes in the network layers
+        """
+        super(Critic, self).__init__()
+        self.fc1 = nn.Linear(state_size+action_size, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, hidden_size)
+        self.fc3 = nn.Linear(hidden_size, 1)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
+        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
+        self.fc3.weight.data.uniform_(-3e-3, 3e-3)
+
+    def forward(self, state, action):
+        """Build a critic (value) network that maps (state, action) pairs -> Q-values."""
+        x = torch.cat((state, action), dim=-1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
\ No newline at end of file
diff --git a/src/cql_sac/utils.py b/src/cql_sac/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f3a87ce2353474c35e3b42b00934f514400d0bf
--- /dev/null
+++ b/src/cql_sac/utils.py
@@ -0,0 +1,43 @@
+import torch
+import numpy as np
+
+def save(args, save_name, model, wandb, ep=None):
+    import os
+    save_dir = './trained_models/' 
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    if ep is not None:
+        torch.save(model.state_dict(), save_dir + args.run_name + save_name + str(ep) + ".pth")
+        wandb.save(save_dir + args.run_name + save_name + str(ep) + ".pth")
+    else:
+        torch.save(model.state_dict(), save_dir + args.run_name + save_name + ".pth")
+        wandb.save(save_dir + args.run_name + save_name + ".pth")
+
+def collect_random(env, dataset, num_samples=200):
+    state = env.reset()
+    for _ in range(num_samples):
+        action = env.action_space.sample()
+        next_state, reward, done, _ = env.step(action)
+        dataset.add(state, action, reward, next_state, done)
+        state = next_state
+        if done:
+            state = env.reset()
+
+def evaluate(env, policy, eval_runs=5): 
+    """
+    Makes an evaluation run with the current policy
+    """
+    reward_batch = []
+    for i in range(eval_runs):
+        state = env.reset()
+
+        rewards = 0
+        while True:
+            action = policy.get_action(state, eval=True)
+
+            state, reward, done, _ = env.step(action)
+            rewards += reward
+            if done:
+                break
+        reward_batch.append(rewards)
+    return np.mean(reward_batch)
\ No newline at end of file
diff --git a/src/environment.py b/src/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb89717d076740298985b5eccfa8703497562c7
--- /dev/null
+++ b/src/environment.py
@@ -0,0 +1,29 @@
+import safety_gymnasium
+import gymnasium
+import numpy as np
+
+class Gymnasium2SafetyGymnasium(gymnasium.Wrapper):
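+    """
+    Adapts a plain Gymnasium (vector) env to the safety-gymnasium step signature:
+    step() returns a 6-tuple that additionally contains a cost, taken from
+    info['cost'] when available and zero otherwise.
+    """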
+    def step(self, action):
+        state, reward, terminated, truncated, info = super().step(action)
+        if 'cost' in info:
+            cost = info['cost']
+        elif isinstance(reward, (int, float)):
+            cost = 0
+        elif isinstance(reward, np.ndarray):
+            cost = np.zeros_like(reward)
+        elif isinstance(reward, list):
+            cost = [0]*len(reward)
+        else:
+            raise NotImplementedError("reward type not recognized") # for now
+        return state, reward, cost, terminated, truncated, info
+
+def create_environment(args):
+    if args.env_id.startswith("Safety"):
+        env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
+    if args.env_id.startswith("Gymnasium_"):
+        id = args.env_id[len('Gymnasium_'):]
+        env = gymnasium.make_vec(id, num_envs=args.num_vectorized_envs, vectorization_mode="sync")
+        env = Gymnasium2SafetyGymnasium(env)
+    if args.env_id.startswith("RaceTrack"):
+        raise NotImplementedError("RaceTrack environment is not implemented yet.")
+    return env    
\ No newline at end of file
diff --git a/src/networks.py b/src/networks.py
index 85f6bf9c876b530fc43fcfd59125a67faa166523..5d73c344d59e2718f98ddf48e8d5b09815dacc41 100644
--- a/src/networks.py
+++ b/src/networks.py
@@ -3,10 +3,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import Normal
 
-"""
-Copied from spice project.
-"""
-
 LOG_SIG_MAX = 2
 LOG_SIG_MIN = -20
 epsilon = 1e-6
@@ -36,24 +32,23 @@ class ValueNetwork(nn.Module):
 
 
 class QNetwork(nn.Module):
-    def __init__(self, num_inputs, num_actions, hidden_dim):
+    def __init__(self, num_inputs, num_actions, hidden_dim, sigmoid_activation=False):
         super().__init__()
 
-        # Q1 architecture
         self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
         self.linear3 = nn.Linear(hidden_dim, 1)
+        self.last_activation = torch.sigmoid if sigmoid_activation else nn.Identity()
 
         self.apply(weights_init_)
 
     def forward(self, state, action):
-        xu = torch.cat([state, action], 1)
-
-        x1 = F.relu(self.linear1(xu))
-        x1 = F.relu(self.linear2(x1))
-        x1 = self.linear3(x1)
-
-        return x1#, x2
+        x = torch.cat([state, action], -1)
+        x = F.relu(self.linear1(x))
+        x = F.relu(self.linear2(x))
+        x = self.linear3(x)
+        x = self.last_activation(x)
+        return x
 
 
 class GaussianPolicy(nn.Module):
@@ -74,12 +69,9 @@ class GaussianPolicy(nn.Module):
             self.action_bias = torch.tensor(0.)
         else:
             self.action_scale = torch.FloatTensor(
-                (action_space.high - action_space.low) / 2.)
+                (action_space.high[0] - action_space.low[0]) / 2.)
             self.action_bias = torch.FloatTensor(
-                (action_space.high + action_space.low) / 2.)
-
-    def __call__(self, state, action):
-        return self.log_prob(state, action)
+                (action_space.high[0] + action_space.low[0]) / 2.)
 
     def forward(self, state):
         x = F.relu(self.linear1(state))
@@ -90,28 +82,42 @@ class GaussianPolicy(nn.Module):
         return mean, log_std
     
     def distribution(self, state):
-        mean, log_std = self.forward(state)
-        std = log_std.exp()
-        normal = Normal(mean, std, validate_args=False)
+        mean, log_std = self.forward(state)
+        std = log_std.exp()
+        normal = Normal(mean, std)
         return normal
     
     def log_prob(self, state, action):
         dist = self.distribution(state)
-        return dist.log_prob(action)
+        # invert the tanh squashing; clamp to avoid atanh(+-1) = +-inf at the action bounds
+        y_t = (action - self.action_bias) / self.action_scale
+        y_t = y_t.clamp(-1 + epsilon, 1 - epsilon)
+        x_t = torch.atanh(y_t)
+        # change-of-variables correction for the tanh transform
+        log_prob = dist.log_prob(x_t)
+        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
+        log_prob = log_prob.sum(dim=-1, keepdim=True)
+        return log_prob
 
-    def sample(self, state):
-        mean, log_std = self.forward(state)
-        std = log_std.exp()
-        normal = Normal(mean, std)
-        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
+    def sample(self, state, num_samples=1):
+        normal = self.distribution(state)
+        x_t = normal.rsample((num_samples, ))  # for reparameterization trick (mean + std * N(0,1))
         y_t = torch.tanh(x_t)
         action = y_t * self.action_scale + self.action_bias
-        log_prob = normal.log_prob(x_t)
-        # Enforcing Action Bound
-        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
-        log_prob = log_prob.sum(1, keepdim=True)
-        mean = torch.tanh(mean) * self.action_scale + self.action_bias
-        return action, log_prob, mean
+        return action
 
     def to(self, device):
         self.action_scale = self.action_scale.to(device)
diff --git a/src/policy.py b/src/policy.py
index d15383db54e784be8a2ca5a981c1b31ace07f262..ae73543e1805f2c0bc48ad6f9196468c0227134b 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -1,11 +1,13 @@
 import torch
 from torch.distributions import kl_divergence
-from torch.func import functional_call, vmap, grad
-from .networks import GaussianPolicy, QNetwork, ValueNetwork
+from torch.nn.functional import mse_loss
+from src.networks import GaussianPolicy, QNetwork, ValueNetwork
+from src.helper import soft_update, hard_update, apply_gradient
 
 class CSCAgent():
-    def __init__(self, env, args, writer) -> None:
-        self._writer = writer
+    def __init__(self, env, args, buffer, stats) -> None:
+        self._buffer = buffer
+        self._stats = stats
 
         self._device = args.device
         self._shield_iterations = args.shield_iterations
@@ -16,58 +18,44 @@ class CSCAgent():
         self._gamma = args.csc_gamma
         self._delta = args.csc_delta
         self._chi = args.csc_chi
-        self._avg_failures = self._chi
+        self._avg_train_unsafe = self._chi
+        self._unsafety_threshold = (1 - self._gamma) * (self._chi - self._avg_train_unsafe)  # initialized so shielded sampling works before the first optimization phase
         self._batch_size = args.batch_size
-        self._lambda = args.csc_lambda
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
-
-        num_inputs = env.observation_space.shape[0]
-        num_actions = env.action_space.shape[0]
-        hidden_dim = args.hidden_dim
-        self._policy = GaussianPolicy(num_inputs, num_actions, hidden_dim, env.action_space).to(self._device)
-        self._safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device)
-        self._value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device)
-
-        self._target_safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device)
-        self._target_value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device)
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=1)
-        self.soft_update(self._target_value_network, self._value_network, tau=1)
+        self._hidden_dim = args.hidden_dim
+        self._lambda_lr = args.csc_lambda_lr
+        self._shielded_action_sampling = args.shielded_action_sampling
+
+        num_inputs = env.observation_space.shape[-1]
+        num_actions = env.action_space.shape[-1]
+        self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
+        self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
+        self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
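+        # lambda: Lagrange multiplier of the cost constraint, updated by dual gradient steps in _primal_dual_gradient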
+        self._lambda = torch.nn.Parameter(torch.tensor(args.csc_lambda, requires_grad=True, device=self._device))
+
+        self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
+        self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
+        self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
+        hard_update(self._policy_old, self._policy)
+        hard_update(self._target_safety_critic, self._safety_critic)
+        hard_update(self._target_value_network, self._value_network)
 
         self._optim_safety_critic = torch.optim.Adam(self._safety_critic.parameters(), lr=args.csc_safety_critic_lr)
         self._optim_value_network = torch.optim.Adam(self._value_network.parameters(), lr=args.csc_value_network_lr)
 
-        self._policy.eval()
-        self._safety_critic.eval()
-    
-    @staticmethod
-    @torch.no_grad
-    def soft_update(target, source, tau):
-        for tparam, sparam in zip(target.parameters(), source.parameters()):
-            tparam.copy_((1 - tau) * tparam.data + tau * sparam.data)
-
-    @torch.no_grad
-    def _cost_advantage(self, states, actions):
-        cost_actions = self._safety_critic.forward(states, actions)
-        cost_states = torch.zeros_like(cost_actions)
-        for _ in range(self._expectation_estimation_samples):
-            a = self._policy.sample(states)[0].detach()
-            cost_states += self._target_safety_critic.forward(states, a)
-        cost_states /= self._expectation_estimation_samples
-        return cost_actions - cost_states
-    
-
-    def _reward_diff(self, states, rewards, next_states):
+        # TODO: expose a dedicated policy learning rate; the safety-critic lr is reused here for now
+        self._optim_policy = torch.optim.Adam(self._policy.parameters(), lr=args.csc_safety_critic_lr)
+
+    def _update_value_network(self, states, rewards, next_states, dones):
+        """
+        Updates the reward value network using MSE on a batch of experiences.
+        """
         value_states = self._value_network.forward(states)
         with torch.no_grad():
             value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
             value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
-        return value_target - value_states
-
-
-    def _update_value_network(self, states, rewards, next_states):
-        value_diff = self._reward_diff(states, rewards, next_states)
-        loss = 1/2 * torch.square(value_diff).mean()
+        loss = mse_loss(value_states, value_target, reduction='mean')
 
         self._optim_value_network.zero_grad()
         loss.backward()
@@ -75,160 +63,238 @@ class CSCAgent():
         return loss.item()
 
 
-    def _update_policy(self, states, actions, rewards, next_states):
-        @torch.no_grad
-        def estimate_ahat(states, actions, rewards, next_states):
-            lambda_prime = self._lambda / (1 - self._gamma)
-            adv_cost = self._cost_advantage(states, actions)
-            reward_diff = self._reward_diff(states, rewards, next_states)
-            return reward_diff - lambda_prime * adv_cost
-        
-        @torch.no_grad
-        def apply_gradient(gradient):
-            l = 0
-            for name, param in self._policy.named_parameters():
-                if not param.requires_grad: continue
-                n = param.numel()
-                param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
-                l += n
-        
-        def log_prob_grad(states, actions):
-            # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
-            params = {k: v.detach() for k, v in self._policy.named_parameters()}
-            buffers = {k: v.detach() for k, v in self._policy.named_buffers()}
-            def compute_log_prob(params, buffers, s, a):
-                s = s.unsqueeze(0)
-                a = a.unsqueeze(0)
-                lp = functional_call(self._policy, (params, buffers), (s,a))
-                return lp.sum()
-            
-            ft_compute_grad = grad(compute_log_prob)
-            ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
-            ft_per_sample_grads = ft_compute_sample_grad(params, buffers, states, actions)
-
-            grads = []
-            for key, param in self._policy.named_parameters():
-                if not param.requires_grad: continue
-                g = ft_per_sample_grads[key].flatten(start_dim=1)
-                grads.append(g)
-            return torch.hstack(grads)
-
-        @torch.no_grad
-        def fisher_matrix(glog_prob):
-            b, n = glog_prob.shape
-            s,i = 32,0  # batch size, index
-            fisher = torch.zeros((n,n))
-            while i < b:
-                g = glog_prob[i:i+s, ...]
-                fisher += (torch.func.vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b
-                i += s
-            return fisher
-
-        glog_prob = log_prob_grad(states, actions).detach()
-        fisher = fisher_matrix(glog_prob)
-        ahat = estimate_ahat(states, actions, rewards, next_states)
-        gJ = (glog_prob * ahat).mean(dim=0).view((-1, 1))
-
-        beta_term = torch.sqrt((2 * self._delta) / (gJ.T @ fisher @ gJ))
-        try:
-            gradient_term = torch.linalg.solve(fisher, gJ)
-        except RuntimeError:
-            # https://pytorch.org/docs/stable/generated/torch.linalg.pinv.html
-            # NOTE: requires full rank on cuda
-            # gradient_term = torch.linalg.lstsq(fisher.cpu(), gJ.cpu()).solution.to(self._device)  # very slow
-            gradient_term = torch.linalg.pinv(fisher) @ gJ
-
-        del glog_prob, ahat, fisher, gJ
+    def _update_safety_critic(self, states, actions, costs, next_states, dones):
+        """
+        Updates the safety critic network using the CQL objective (Equation 2) from the CSC paper.
+        """
+        # states, action from old policy (from buffer)
+        safety_sa_env = self._safety_critic.forward(states, actions)
 
+        # states from old policy (from buffer), actions from current policy
         with torch.no_grad():
-            dist_old = self._policy.distribution(states)
-            beta_j = self._beta
-            for j in range(self._line_search_iterations):
-                beta_j = beta_j * (1 - beta_j)**j
-                gradient = beta_j * beta_term * gradient_term
-
-                apply_gradient(gradient)
-                dist_new = self._policy.distribution(states)
-                kl_div = kl_divergence(dist_old, dist_new).mean()
-                if kl_div <= self._delta:
-                    return j
-                else:
-                    # NOTE: probably better to save the weights and restore them
-                    apply_gradient(-gradient)
-        return torch.nan
-
-
-    def _update_safety_critic(self, states, actions, costs, next_states):
-        safety_sa_env = self._safety_critic.forward(states, actions)
+            actions_current_policy = self.sample(states, shielded=self._shielded_action_sampling)
+        safety_s_env_a_p = self._safety_critic.forward(states, actions_current_policy)
 
-        a = self._policy.sample(states)[0].detach()
-        safety_s_env_a_p = self._safety_critic.forward(states, a)
+        # first loss term (steers critic towards overapproximation)
+        loss1 = (-safety_s_env_a_p.mean() + safety_sa_env.mean())
 
+        # bellman operator
         with torch.no_grad():
             safety_next_state = torch.zeros_like(safety_sa_env)
             for _ in range(self._expectation_estimation_samples):
-                a = self._policy.sample(next_states)[0].detach()
-                safety_next_state += self._target_safety_critic.forward(next_states, a)
+                actions_next = self.sample(next_states, shielded=self._shielded_action_sampling)
+                safety_next_state += self._target_safety_critic.forward(next_states, actions_next)
             safety_next_state /= self._expectation_estimation_samples
-            safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state
-        safety_sasc_env = torch.square(safety_sa_env - safety_next_state)
+            safety_next_state *= (1 - dones.view((-1, 1)))
+            safety_target = costs.view((-1, 1)) + self._gamma * safety_next_state
+        
+        # second loss term (mse loss)
+        loss2 = mse_loss(safety_sa_env, safety_target, reduction='mean')
 
-        loss = self._alpha * (safety_sa_env.mean() - safety_s_env_a_p.mean()) + 1/2 * safety_sasc_env.mean()
+        # overall weighted loss as sum of loss1 and loss2
+        loss = self._alpha * loss1 + 1/2 * loss2
 
         self._optim_safety_critic.zero_grad()
         loss.backward()
         self._optim_safety_critic.step()
         return loss.item()
 
+    def _tmp(self, states, actions, rewards, next_states, dones):
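+        """
+        Simplified policy update: one manual gradient-ascent step on the importance-weighted
+        reward advantage. Used by _primal_dual_gradient in place of the full update for now.
+        """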
+        # importance sampling weights
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        # importance ratio: w = pi(a|s) / pi_old(a|s); clamp the log-ratio so exp() cannot overflow
+        log_ratio = (action_log_prob - action_log_prob_old).clamp(min=-20, max=20)
+        weighting: torch.Tensor = log_ratio.exp()
+
+        # reward advantage
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+        objective = (weighting * reward_advantage).mean()
+
+        # gradient ascent on the objective; apply_gradient adds lr * grad to the parameters
+        # (the lambda learning rate is reused as the policy step size here)
+        self._policy.zero_grad()
+        objective.backward()
+        apply_gradient(self._policy, lr=self._lambda_lr)
+
+        return objective.item()
+
+
+    def _primal_dual_gradient(self, states, actions, rewards, next_states, dones):
+        """
+        Updates the policy and the lambda parameter using the objective from Equation 45 of the CSC paper.
+        """
+        # NOTE: the simplified update in _tmp is used for now; the full primal-dual update
+        # below is kept for reference but not currently executed.
+        return self._tmp(states, actions, rewards, next_states, dones)
+        # importance sampling factor
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        weighting = (action_log_prob - action_log_prob_old).exp()
+
+        # reward advantage
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+
+        # lambda prime
+        lambda_prime = self._lambda / (1 - self._gamma)
+
+        # cost advantage
+        with torch.no_grad():
+            cost_actions = self._target_safety_critic.forward(states, actions)
+            cost_states = torch.zeros_like(cost_actions)
+            for _ in range(self._expectation_estimation_samples):
+                actions_policy_old = self._sample_policy_old(states, shielded=self._shielded_action_sampling)
+                cost_states += self._target_safety_critic.forward(states, actions_policy_old)
+            cost_states /= self._expectation_estimation_samples
+            cost_advantage = cost_actions - cost_states
+        
+        # chi prime term (eq. 44)
+        chi_prime_term = self._lambda * (self._chi - self._avg_train_unsafe)
+
+        # overall objective
+        objective = (weighting * (reward_advantage - lambda_prime*cost_advantage)).mean() + chi_prime_term
+
+        # reset gradients
+        self._policy.zero_grad()
+        self._lambda.grad = None
 
-    @torch.no_grad
-    def _update_lambda(self, states, actions):
-        gamma_inv = 1 / (1 - self._gamma)
-        adv = self._cost_advantage(states, actions).mean()
-        chi_prime = self._chi - self._avg_failures
-        gradient = (gamma_inv * adv - chi_prime).item()
-        self._lambda -= gradient
-        return gradient
+        # calculate gradients
+        objective.backward()
 
+        # apply policy gradient
+        apply_gradient(self._policy, lr=1)                  # gradient ascent
 
-    def update(self, buffer, avg_failures, total_episodes):
-        self._avg_failures = avg_failures
-        self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
+        # obtain action distributions
+        action_dist_old = self._policy_old.distribution(states)
+        action_dist_new = self._policy.distribution(states)
+        kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
 
-        states, actions, rewards, costs, next_states = buffer.sample(self._batch_size)
-        states = torch.tensor(states)
-        actions = torch.tensor(actions)
-        rewards = torch.tensor(rewards)
-        costs = torch.tensor(costs)
-        next_states = torch.tensor(next_states)
+        # revert update
+        apply_gradient(self._policy, lr=-1)                 # gradient descent back
 
-        piter = self._update_policy(states, actions, rewards, next_states)
-        vloss = self._update_value_network(states, rewards, next_states)
-        sloss = self._update_safety_critic(states, actions, costs, next_states)
-        lgradient = self._update_lambda(states, actions)
+        # calculate beta
+        beta_term = torch.sqrt(self._delta / kl_div)
+        beta_j = self._beta
+
+        # line search
+        for j in range(self._line_search_iterations):
+            beta_j = beta_j * (1 - beta_j)**j
+            beta = beta_j * beta_term
+
+            apply_gradient(self._policy, lr=beta)         # gradient ascent
+            action_dist_new = self._policy.distribution(states)
+            kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
+
+            if kl_div > self._delta:
+                apply_gradient(self._policy, lr=-beta)     # gradient descent back
+            else:
+                break
 
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
-        self.soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        # update lambda
+        with torch.no_grad():
+            self._lambda -= self._lambda_lr * self._lambda.grad   # gradient descent
+            self._lambda.clamp_min_(min=0)
+        
+        return j
 
-        self._writer.add_scalar(f"agent/policy_iterations", piter, total_episodes)
-        self._writer.add_scalar(f"agent/value_loss", round(vloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/safety_loss", round(sloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/lambda_gradient", round(lgradient,4), total_episodes)
 
+    def update(self):
+        """
+        Performs one iteration of updates. Updates the value network, policy network, lambda parameter and safety critic.
+        """
+        self._avg_train_unsafe = self._stats.train_unsafe_avg
+        
+        states, actions, rewards, costs, next_states, dones = self._buffer.sample(self._batch_size)
+        states = torch.tensor(states, device=self._device)
+        actions = torch.tensor(actions, device=self._device)
+        rewards = torch.tensor(rewards, device=self._device)
+        costs = torch.tensor(costs, device=self._device)
+        next_states = torch.tensor(next_states, device=self._device)
+        dones = torch.tensor(dones, device=self._device)
+
+        vloss = self._update_value_network(states, rewards, next_states, dones)
+        soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        piter = self._primal_dual_gradient(states, actions, rewards, next_states, dones)
+        sloss = self._update_safety_critic(states, actions, costs, next_states, dones)
+        soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
+
+        self._stats.writer.add_scalar("debug/vloss", vloss, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/sloss", sloss, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/piter", piter, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/lambda", self._lambda, self._stats.total_train_episodes)
+
+
+    def after_updates(self):
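+        """
+        Called after an optimization phase: refreshes the rejection-sampling threshold
+        (epsilon in the CSC paper) and synchronizes policy_old with the current policy.
+        """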
+        self._unsafety_threshold = (1 - self._gamma) * (self._chi - self._avg_train_unsafe)
+        hard_update(self._policy_old, self._policy)
+
+
+    def _sample_policy_old(self, state, shielded):
+        """
+        Allows action sampling from policy_old.
+        """
+        tmp = self._policy
+        self._policy = self._policy_old
+        actions = self.sample(state, shielded=shielded)
+        self._policy = tmp
+        return actions
 
     def sample(self, state, shielded=True):
-        state = torch.tensor(state).unsqueeze(0)
+        """
+        Samples and returns one action for every state. If shielded is true, performs rejection sampling according to the CSC paper using the safety critic.
+        Instead of using a loop, we sample all actions simultaneously and pick:
+        - the first with an estimated unsafety level <= self._unsafety_threshold (epsilon in the CSC paper)
+        - or else, the one that achieves maximum safety, i.e., lowest estimated unsafety
+        """
+        state = torch.tensor(state, device=self._device, dtype=torch.float64)
         if shielded:
-            state = state.expand((self._shield_iterations, -1))
-            action = self._policy.sample(state)[0].detach()
-            unsafety = self._safety_critic.forward(state, action).squeeze()
-            mask = (unsafety <= self._unsafety_threshold).nonzero().flatten()
-            if mask.numel() > 0:
-                idx = mask[0]
-            else:
-                idx = torch.argmin(unsafety)
-                if idx.numel() > 1: idx = idx[0]
-            return action[idx].cpu().numpy()
+            # sample all actions, expand/copy state, estimate safety for all actions at the same time
+            actions = self._policy.sample(state, num_samples=self._shield_iterations) # shape: (shield_iterations, batch_size, action_size)
+            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1)) # shape: (shield_iterations, batch_size, state_size)
+            unsafety = self._safety_critic.forward(state, actions) # shape: (shield_iterations, batch_size, 1)
+            unsafety = unsafety.squeeze(2) # shape: (shield_iterations, batch_size)
+
+            # check for actions that qualify (unsafety <= threshold), locate first (if exists) for every state
+            mask = unsafety <= self._unsafety_threshold # shape: (shield_iterations, batch_size)
+            row_idx = mask.int().argmax(dim=0) # index of the first potentially safe action per state
+            batch_idx = torch.arange(mask.shape[1], device=mask.device)
+
+            # retrieve estimated safety of said action and check if it is indeed <= threshold
+            # if for a state all actions are unsafe (> threshold), argmax will return the first unsafe as mask[state] == 0 everywhere
+            unsafety_chosen = unsafety[row_idx, batch_idx]
+            all_unsafe = unsafety_chosen > self._unsafety_threshold
+
+            # for such states retrieve the action with a minimum level of unsafety
+            row_idx[all_unsafe] = unsafety[:, all_unsafe].argmin(dim=0)
+
+            # sample and return actions according to idxs
+            result = actions[row_idx, batch_idx, :]
+            return result
+
         else:
-            action = self._policy.sample(state)[0].detach().squeeze().cpu().numpy()
-            return action
\ No newline at end of file
+            # simply sample for every state one action
+            actions = self._policy.sample(state) # shape: (num_samples, batch_size, action_size)
+            return actions.squeeze(0)
\ No newline at end of file
diff --git a/src/stats.py b/src/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f263bf73e5ef7fb25774d047297b8fcbfce7bf9
--- /dev/null
+++ b/src/stats.py
@@ -0,0 +1,60 @@
+import numpy as np
+
+class Statistics:
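+    """
+    Collects training and test statistics and writes them to TensorBoard via the provided SummaryWriter.
+    """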
+    def __init__(self, writer):
+        self.writer = writer
+
+        # Total training statistics
+        self.total_train_episodes = 0
+        self.total_train_steps = 0
+        self.total_train_unsafe = 0
+        self.total_updates = 0
+
+        # Used for calculating average unsafe of previous training episodes
+        self.train_unsafe_history = []
+        self._train_unsafe_avg = 0
+
+        # Used for averaging test results
+        self.test_steps_history = []
+        self.test_reward_history = []
+        self.test_cost_history = []
+        self.test_unsafe_history = []
+
+    @property
+    def train_unsafe_avg(self):
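+        """
+        Mean of the unsafe counts logged since the last read. Reading this property clears the
+        history; if nothing new was logged, the previously computed average is returned.
+        """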
+        if len(self.train_unsafe_history) > 0:
+            self._train_unsafe_avg = sum(self.train_unsafe_history) / len(self.train_unsafe_history)
+            self.train_unsafe_history.clear()
+        return self._train_unsafe_avg
+    
+    def log_train_history(self, unsafe: np.ndarray):
+        self.train_unsafe_history += unsafe.tolist()
+    
+    def log_test_history(self, steps: np.ndarray, reward: np.ndarray, cost: np.ndarray, unsafe: np.ndarray):
+        self.test_steps_history += steps.tolist()
+        self.test_reward_history += reward.tolist()
+        self.test_cost_history += cost.tolist()
+        self.test_unsafe_history += unsafe.tolist()
+    
+    def log_test_tensorboard(self, name_prefix, ticks):
+        self.log_tensorboard(name_prefix + '/avg_steps', np.array(self.test_steps_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_returns', np.array(self.test_reward_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_costs', np.array(self.test_cost_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_unsafes', np.array(self.test_unsafe_history).mean(), ticks)
+
+        self.test_steps_history.clear()
+        self.test_reward_history.clear()
+        self.test_cost_history.clear()
+        self.test_unsafe_history.clear()
+    
+    def log_update_tensorboard(self, data:dict):
+        for k, v in data.items():
+            name = f"update/{k}"
+            self.writer.add_scalar(name, v, self.total_updates)
+
+    def log_tensorboard(self, name: str, values: np.ndarray, ticks: np.ndarray):
+        values = np.asarray(values)
+        ticks = np.asarray(ticks)
+        if values.size == 1:
+            self.writer.add_scalar(name, values.item(), ticks.item())
+        else:
+            for v, t in zip(values, ticks):
+                self.writer.add_scalar(name, v, t)
\ No newline at end of file