diff --git a/main.py b/main.py
index 7eef356d171125bab10280363f2ecdecc547f607..d9d5b7c0eb05053b4bbe887549ca9274f02fec99 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from src.stats import Statistics
 from src.buffer import ReplayBuffer
-from src.policy import CSCAgent
+from src.cql_sac.agent import CQLSAC
 
 ##################
 # ARGPARSER
@@ -28,6 +28,8 @@ def cmd_args():
 
     # train and test args
     train_test_args = parser.add_argument_group('Train and Test')
+    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
+                    help="Total number of steps until training is finished")
     train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
                         help="Number of episodes until policy optimization")
     train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
@@ -36,12 +38,17 @@ def cmd_args():
                         help="Number of updates performed after each training step")
     train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
                         help="Number of episodes used for testing")
-    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
-                        help="Total number of steps until training is finished")
-    train_test_args.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
+
+    # update args
+    update_args = parser.add_argument_group('Update')
+    update_args.add_argument("--batch_size", action="store", type=int, default=256, metavar="N",
                         help="Batch size used for training")
-    train_test_args.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
+    update_args.add_argument("--tau", action="store", type=float, default=5e-3, metavar="N",
                         help="Factor used in soft update of target network")
+    update_args.add_argument("--gamma", action="store", type=float, default=0.95, metavar="N",
+                        help="Discount factor for rewards")
+    update_args.add_argument("--learning_rate", action="store", type=float, default=3e-4, metavar="N",
+                        help="Learning rate for the policy and Q networks")
     
     # buffer args
     buffer_args = parser.add_argument_group('Buffer')
@@ -49,39 +56,35 @@ def cmd_args():
                         help="Define the maximum capacity of the replay buffer")
     buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
                         help="Clear Replay Buffer after every optimization step")
-
+
+    # network args
+    network_args = parser.add_argument_group('Networks')
+    network_args.add_argument("--hidden_size", action="store", type=int, default=256, metavar="N",
+                        help="Hidden size of the networks")
+
+    # cql args
+    cql_args = parser.add_argument_group('CQL')
+    cql_args.add_argument("--cql_with_lagrange", action="store", type=float, default=0, metavar="N",
+                        help="Non-zero enables the Lagrange formulation of CQL, tuning the CQL weight against the target action gap")
+    cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
+                        help="Temperature used in the log-sum-exp term of the CQL regularizer")
+    cql_args.add_argument("--cql_weight", action="store", type=float, default=1.0, metavar="N",
+                        help="Weight of the conservative (CQL) penalty added to the critic loss")
+    cql_args.add_argument("--cql_target_action_gap", action="store", type=float, default=10, metavar="N",
+                        help="Target action gap used as the threshold when the Lagrange formulation is enabled")
+
     # csc args
-    csc_args = parser.add_argument_group('Agent')
-    csc_args.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
-                        help="Maximum number of actions sampled during shielding")
-    csc_args.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
-                        help="Maximum number of line search update iterations")
-    csc_args.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
-                        help="Number of samples to estimate expectations")
+    csc_args = parser.add_argument_group('CSC')
     csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
                         help="Set the value of chi")
     csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
                         help="Set the value of delta")
-    csc_args.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N",
-                        help="Set the value of gamma")
     csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
                         help="Set the value of beta")
     csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
                         help="Set the value of alpha")
     csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
                         help="Set the initial value of lambda")
-    csc_args.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
-                        help="Learn rate for the safety critic")
-    csc_args.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
-                        help="Learn rate for the value network")
-    csc_args.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N",
-                        help="Learn rate for the lambda dual variable")
-    csc_args.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
-                        help="Hidden dimension of the networks")
-    csc_args.add_argument("--sigmoid_activation", action="store_true", default=False,
-                        help="Apply sigmoid activation to the safety critics output")
-    csc_args.add_argument("--shielded_action_sampling", action="store_true", default=False,
-                        help="Sample shielded actions when performing parameter updates")
 
     # common args
     common_args = parser.add_argument_group('Common')
@@ -114,10 +117,10 @@ def setup(args):
     with open(os.path.join(output_dir, "config.json"), "w") as file:
         json.dump(args.__dict__, file, indent=2)
 
-    env = env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
+    env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
     stats = Statistics(writer)
-    agent = CSCAgent(env, args, buffer, stats)
+    agent = CQLSAC(env=env, args=args, stats=stats)
 
     return env, agent, buffer, stats
 
@@ -147,8 +150,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie
     
     while running_episodes > 0:
         # sample and execute actions
-        with torch.no_grad():
-            actions = agent.sample(state, shielded=shielded).cpu().numpy()
+        actions = agent.get_action(state, eval=not train).cpu().numpy()
         next_state, reward, cost, terminated, truncated, info = env.step(actions)
         done = terminated | truncated
         not_done_masked = ((~done) & mask)
@@ -174,7 +176,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie
                     reward[not_done_masked], 
                     cost[not_done_masked], 
                     next_state[not_done_masked],
-                    done[not_done_masked]
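+                    # store terminated (not done) so time-limit truncations are not treated as terminal states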
+                    terminated[not_done_masked]
                 )
                 buffer.add(
                     state[done_masked], 
@@ -182,7 +184,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie
                     reward[done_masked], 
                     cost[done_masked], 
                     np.stack(info['final_observation'], axis=0)[done_masked],
-                    done[done_masked]
+                    terminated[done_masked]
                 )
 
                 # record finished episodes
@@ -247,10 +249,9 @@ def main(args, env, agent, buffer, stats):
 
             # 2. Perform updates
             for _ in range(args.update_iterations):
-                agent.update()
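+                # each update step trains on a fresh mini-batch of batch_size transitions from the replay buffer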
+                agent.learn(experiences=buffer.sample(n=args.batch_size))
 
             # 3. After update stuff
-            agent.after_updates()
             if args.clear_buffer:
                 buffer.clear()