diff --git a/main.py b/main.py
index a32e903157287163fb07253f67282d58e4646247..8215e83da8493c2b95087fe51331d78bf901b928 100644
--- a/main.py
+++ b/main.py
@@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from src.stats import Statistics
 from src.buffer import ReplayBuffer
-from src.cql_sac.agent import CQLSAC
+from src.cql_sac.agent import CSCCQLSAC
 from src.environment import create_environment
 
 ##################
@@ -64,7 +64,7 @@ def cmd_args():
     
     # cql args
     cql_args = parser.add_argument_group('CQL')
-    cql_args.add_argument("--cql_with_lagrange", action="store", type=float, default=0, metavar="N",  
+    cql_args.add_argument("--cql_with_lagrange", action="store_true", default=False,
                         help="")
     cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
                         help="")
@@ -121,7 +121,7 @@ def setup(args):
     env = create_environment(args=args)
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
     stats = Statistics(writer=writer)
-    agent = CQLSAC(env=env, args=args, stats=stats)
+    agent = CSCCQLSAC(env=env, args=args, stats=stats)
 
     return env, agent, buffer, stats
 
@@ -195,6 +195,8 @@ def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train
                 stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks)
                 stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks)
 
+                stats.log_train_history((episode_cost[done_masked] > args.cost_limit).astype(np.uint8))
+
                 stats.total_train_episodes += done_masked_count
                 stats.total_train_steps += episode_steps[done_masked].sum()
                 stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
@@ -305,8 +307,8 @@ def main(args, env, agent, buffer, stats):
                 buffer.clear()
 
         # Test loop (shielded and unshielded)
-        for shielded in [True, False]:
-            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
+        # Unshielded testing disabled: evaluate the shielded policy only
+        run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=True)
 
 if __name__ == '__main__':
     args = cmd_args()
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
index 85abbdca76bdf71552c718fe603f60bd645c5662..3e6e4cb0cef61d21173db14de1c7c44b4b89f447 100644
--- a/src/cql_sac/agent.py
+++ b/src/cql_sac/agent.py
@@ -9,7 +9,7 @@ import math
 import copy
 
 
-class CQLSAC(nn.Module):
+class CSCCQLSAC(nn.Module):
     """Interacts with and learns from the environment."""
     
     def __init__(self,
@@ -24,7 +24,9 @@ class CQLSAC(nn.Module):
             env : the vector environment
             args : the argparse arguments
         """
-        super(CQLSAC, self).__init__()
+        super(CSCCQLSAC, self).__init__()
+        self.stats = stats
+
         state_size = env.observation_space.shape[-1]
         action_size = env.action_space.shape[-1]
         hidden_size = args.hidden_size
@@ -39,7 +41,7 @@ class CQLSAC(nn.Module):
 
         self.target_entropy = -action_size  # -dim(A)
 
-        self.log_alpha = torch.tensor([0.0], requires_grad=True)
+        self.log_alpha = torch.tensor([0.0], requires_grad=True, device=self.device)
         self.alpha = self.log_alpha.exp().detach()
         self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=self.learning_rate) 
         
@@ -48,19 +50,41 @@ class CQLSAC(nn.Module):
         self.cql_temp = args.cql_temp
         self.cql_weight = args.cql_weight
         self.cql_target_action_gap = args.cql_target_action_gap
-        self.cql_log_alpha = torch.zeros(1, requires_grad=True)
+        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
         self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate) 
+
+        # CSC params
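+        # csc_chi is the target rate of unsafe episodes; csc_avg_unsafe tracks the observed
+        # training rate (initialized to chi and refreshed from stats in learn())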
+        self.csc_shield_iterations = 100  # max resampling attempts when searching for safe actions
+        self.csc_alpha = args.csc_alpha
+        self.csc_beta = args.csc_beta
+        self.csc_delta = args.csc_delta
+        self.csc_chi = args.csc_chi
+        self.csc_avg_unsafe = args.csc_chi
+
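+        # Lagrange multiplier for the safety constraint; updated via csc_lambda_loss in learn()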
+        self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device)
+        self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) 
         
         # Actor Network 
         self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device)
-        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.learning_rate)     
+        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.learning_rate)
+
+        # Safety Critic Network (w/ Target Network)
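+        # The safety critics estimate the discounted future cost of a state-action pair and
+        # are trained with a conservative (CQL-style) objective in learn()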
+        self.safety_critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic2 = Critic(state_size, action_size, hidden_size).to(self.device)
+        
+        self.safety_critic1_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic1_target.load_state_dict(self.safety_critic1.state_dict())
+
+        self.safety_critic2_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic2_target.load_state_dict(self.safety_critic2.state_dict())
+
+        self.safety_critic1_optimizer = optim.Adam(self.safety_critic1.parameters(), lr=self.learning_rate)
+        self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate) 
         
         # Critic Network (w/ Target Network)
         self.critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
         self.critic2 = Critic(state_size, action_size, hidden_size).to(self.device)
         
-        assert self.critic1.parameters() != self.critic2.parameters()
-        
         self.critic1_target = Critic(state_size, action_size, hidden_size).to(self.device)
         self.critic1_target.load_state_dict(self.critic1.state_dict())
 
@@ -72,32 +96,54 @@ class CQLSAC(nn.Module):
 
     
     def get_action(self, state, eval=False):
-        """Returns actions for given state as per current policy."""
-        state = torch.from_numpy(state).float().to(self.device)
+        """ 
+        Returns shielded actions for given state as per current policy. 
         
-        with torch.no_grad():
-            if eval:
-                action = self.actor_local.get_det_action(state)
-            else:
-                action = self.actor_local.get_action(state)
-        return action
+        Note: eval is currently ignored; actions are always sampled stochastically and shielded.
+        """
+        state = torch.from_numpy(state).float().to(self.device)
+
+        batch_size = state.shape[0]
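+        # Safety budget for the shield: an action is accepted once its estimated discounted cost
+        # falls below this threshold, the scaled gap between the target rate (chi) and the observed training rate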
+        unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
+        unsafety_best = torch.full((batch_size,), fill_value=unsafety_threshold + 1, device=self.device)
+        action_best = torch.zeros(batch_size, self.action_size, device=self.device)
+
+        # Run at max 'csc_shield_iterations' iterations to find safe action
+        for _ in range(self.csc_shield_iterations):
+            # If all actions are already safe, break
+            mask_safe = unsafety_best <= unsafety_threshold
+            if mask_safe.all():
+                break
+
+            # Sample new actions
+            with torch.no_grad():
+                action = self.actor_local.get_action(state).to(self.device)
+            
+            # Estimate safety of new actions (no gradients needed at action-selection time)
+            with torch.no_grad():
+                q1 = self.safety_critic1(state, action)
+                q2 = self.safety_critic2(state, action)
+                unsafety = torch.min(q1, q2).squeeze(1)
+
+            # Update best actions if they are still unsafe and new actions are safer
+            mask_update = (~mask_safe) & (unsafety < unsafety_best)
+            unsafety_best[mask_update] = unsafety[mask_update]
+            action_best[mask_update] = action[mask_update]
+
+        return action_best
 
     def calc_policy_loss(self, states, alpha):
         actions_pred, log_pis = self.actor_local.evaluate(states)
 
         q1 = self.critic1(states, actions_pred.squeeze(0))
         q2 = self.critic2(states, actions_pred.squeeze(0))
-        min_Q = torch.min(q1,q2).cpu()
-        actor_loss = ((alpha * log_pis.cpu() - min_Q )).mean()
+        min_Q = torch.min(q1,q2)
+        actor_loss = ((alpha * log_pis - min_Q )).mean()
         return actor_loss, log_pis
 
     def _compute_policy_values(self, obs_pi, obs_q):
         #with torch.no_grad():
         actions_pred, log_pis = self.actor_local.evaluate(obs_pi)
-        
-        qs1 = self.critic1(obs_q, actions_pred)
-        qs2 = self.critic2(obs_q, actions_pred)
-        
+        qs1 = self.safety_critic1(obs_q, actions_pred)
+        qs2 = self.safety_critic2(obs_q, actions_pred)
         return qs1 - log_pis.detach(), qs2 - log_pis.detach()
     
     def _compute_random_values(self, obs, actions, critic):
@@ -118,6 +164,7 @@ class CQLSAC(nn.Module):
             experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples 
             gamma (float): discount factor
         """
+        self.csc_avg_unsafe = self.stats.train_unsafe_avg
         states, actions, rewards, costs, next_states, dones = experiences
 
         states = torch.from_numpy(states).float().to(self.device)
@@ -127,49 +174,63 @@ class CQLSAC(nn.Module):
         next_states = torch.from_numpy(next_states).float().to(self.device)
         dones = torch.from_numpy(dones).float().to(self.device).view(-1, 1)
 
-        # ---------------------------- update actor ---------------------------- #
-        current_alpha = copy.deepcopy(self.alpha)
-        actor_loss, log_pis = self.calc_policy_loss(states, current_alpha)
-        self.actor_optimizer.zero_grad()
-        actor_loss.backward()
-        self.actor_optimizer.step()
-        
-        # Compute alpha loss
-        alpha_loss = - (self.log_alpha.exp() * (log_pis.cpu() + self.target_entropy).detach().cpu()).mean()
-        self.alpha_optimizer.zero_grad()
-        alpha_loss.backward()
-        self.alpha_optimizer.step()
-        self.alpha = self.log_alpha.exp().detach()
-
         # ---------------------------- update critic ---------------------------- #
         # Get predicted next-state actions and Q values from target models
         with torch.no_grad():
             next_action, new_log_pi = self.actor_local.evaluate(next_states)
             Q_target1_next = self.critic1_target(next_states, next_action)
             Q_target2_next = self.critic2_target(next_states, next_action)
-            Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha.to(self.device) * new_log_pi
+            Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha * new_log_pi
             # Compute Q targets for current states (y_i)
             Q_targets = rewards + (self.gamma * (1 - dones) * Q_target_next)
 
-
         # Compute critic loss
         q1 = self.critic1(states, actions)
         q2 = self.critic2(states, actions)
 
         critic1_loss = F.mse_loss(q1, Q_targets)
         critic2_loss = F.mse_loss(q2, Q_targets)
+
+        # Update critics
+        # critic 1
+        self.critic1_optimizer.zero_grad()
+        critic1_loss.backward(retain_graph=True)
+        clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param)
+        self.critic1_optimizer.step()
+        # critic 2
+        self.critic2_optimizer.zero_grad()
+        critic2_loss.backward()
+        clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param)
+        self.critic2_optimizer.step()
+
+        # ---------------------------- update safety critic ---------------------------- #
+        # Get predicted next-state actions and Q values from target models
+        with torch.no_grad():
+            next_action, new_log_pi = self.actor_local.evaluate(next_states)
+            Q_target1_next = self.safety_critic1_target(next_states, next_action)
+            Q_target2_next = self.safety_critic2_target(next_states, next_action)
+            Q_target_next = torch.min(Q_target1_next, Q_target2_next)  # entropy bonus from the reward critic target is omitted here
+            # Compute Q targets for current states (y_i)
+            Q_targets = costs + (self.gamma * (1 - dones) * Q_target_next)
+
+        # Compute safety_critic loss
+        q1 = self.safety_critic1(states, actions)
+        q2 = self.safety_critic2(states, actions)
+
+        safety_critic1_loss = F.mse_loss(q1, Q_targets)
+        safety_critic2_loss = F.mse_loss(q2, Q_targets)
         
         # CQL addon
-        random_actions = torch.FloatTensor(q1.shape[0] * 10, actions.shape[-1]).uniform_(-1, 1).to(self.device)
-        num_repeat = int (random_actions.shape[0] / states.shape[0])
+        num_repeat = 10
+        random_actions = torch.FloatTensor(q1.shape[0] * num_repeat, actions.shape[-1]).uniform_(-1, 1).to(self.device)
         temp_states = states.unsqueeze(1).repeat(1, num_repeat, 1).view(states.shape[0] * num_repeat, states.shape[1])
         temp_next_states = next_states.unsqueeze(1).repeat(1, num_repeat, 1).view(next_states.shape[0] * num_repeat, next_states.shape[1])
         
         current_pi_values1, current_pi_values2  = self._compute_policy_values(temp_states, temp_states)
         next_pi_values1, next_pi_values2 = self._compute_policy_values(temp_next_states, temp_states)
         
-        random_values1 = self._compute_random_values(temp_states, random_actions, self.critic1).reshape(states.shape[0], num_repeat, 1)
-        random_values2 = self._compute_random_values(temp_states, random_actions, self.critic2).reshape(states.shape[0], num_repeat, 1)
+        random_values1 = self._compute_random_values(temp_states, random_actions, self.safety_critic1).reshape(states.shape[0], num_repeat, 1)
+        random_values2 = self._compute_random_values(temp_states, random_actions, self.safety_critic2).reshape(states.shape[0], num_repeat, 1)
         
         current_pi_values1 = current_pi_values1.reshape(states.shape[0], num_repeat, 1)
         current_pi_values2 = current_pi_values2.reshape(states.shape[0], num_repeat, 1)
@@ -183,14 +244,14 @@ class CQLSAC(nn.Module):
         assert cat_q1.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q1 instead has shape: {cat_q1.shape}"
         assert cat_q2.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q2 instead has shape: {cat_q2.shape}"
         
-
-        cql1_scaled_loss = ((torch.logsumexp(cat_q1 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) - q1.mean()) * self.cql_weight
-        cql2_scaled_loss = ((torch.logsumexp(cat_q2 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) - q2.mean()) * self.cql_weight
+        # Sign flipped relative to standard CQL: minimizing this term pushes the safety critics'
+        # cost estimates up on out-of-distribution actions and down on the dataset actions
+        cql1_scaled_loss = -(torch.logsumexp(cat_q1 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q1.mean() * self.cql_weight)
+        cql2_scaled_loss = -(torch.logsumexp(cat_q2 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q2.mean() * self.cql_weight)
         
         cql_alpha_loss = torch.FloatTensor([0.0])
-        cql_alpha = torch.FloatTensor([0.0])
+        cql_alpha = torch.FloatTensor([1.0])  # placeholder for logging when the Lagrange variant is disabled
         if self.cql_with_lagrange:
-            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), min=0.0, max=1000000.0).to(self.device)
+            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), min=0.0, max=1000000.0)
             cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.cql_target_action_gap)
             cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.cql_target_action_gap)
 
@@ -199,27 +260,95 @@ class CQLSAC(nn.Module):
             cql_alpha_loss.backward(retain_graph=True)
             self.cql_alpha_optimizer.step()
         
-        total_c1_loss = critic1_loss + cql1_scaled_loss
-        total_c2_loss = critic2_loss + cql2_scaled_loss
+        total_c1_loss = safety_critic1_loss + cql1_scaled_loss
+        total_c2_loss = safety_critic2_loss + cql2_scaled_loss
         
         
-        # Update critics
-        # critic 1
-        self.critic1_optimizer.zero_grad()
+        # Update safety_critics
+        # safety_critic 1
+        self.safety_critic1_optimizer.zero_grad()
         total_c1_loss.backward(retain_graph=True)
-        clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param)
-        self.critic1_optimizer.step()
-        # critic 2
-        self.critic2_optimizer.zero_grad()
+        clip_grad_norm_(self.safety_critic1.parameters(), self.clip_grad_param)
+        self.safety_critic1_optimizer.step()
+        # safety_critic 2
+        self.safety_critic2_optimizer.zero_grad()
         total_c2_loss.backward()
-        clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param)
-        self.critic2_optimizer.step()
+        clip_grad_norm_(self.safety_critic2.parameters(), self.clip_grad_param)
+        self.safety_critic2_optimizer.step()
+
+        # ---------------------------- update csc lambda ---------------------------- #
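+        # Enforces the constraint: avg_unsafe + cost_advantage / (1 - gamma) <= chi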
+        # Estimate cost advantage
+        with torch.no_grad():
+            q1 = self.safety_critic1(states, actions)
+            q2 = self.safety_critic2(states, actions)
+            v = torch.min(q1, q2)
+
+            new_action, new_log_pi = self.actor_local.evaluate(states)
+            q1 = self.safety_critic1(states, new_action)
+            q2 = self.safety_critic2(states, new_action)
+            q = torch.min(q1, q2)
+
+            cost_advantage = (q - v).mean()
+
+        # Compute csc lambda loss: minimizing -lambda * (constraint value) performs gradient
+        # ascent on lambda, so lambda grows while the estimated unsafety exceeds the target chi
+        csc_lambda_loss = -self.csc_lambda*(self.csc_avg_unsafe + (1 / (1 - self.gamma)) * cost_advantage - self.csc_chi)
+
+        self.csc_lambda_optimizer.zero_grad()
+        csc_lambda_loss.backward()
+        self.csc_lambda_optimizer.step()
+
+        # ---------------------------- update actor ---------------------------- #
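+        # The actor maximizes the entropy-regularized reward advantage of its sampled actions
+        # over the replayed actions (the baseline v is detached)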
+        # Estimate reward advantage
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
+        v = torch.min(q1, q2).detach()
+
+        new_action, new_log_pi = self.actor_local.evaluate(states)
+        q1 = self.critic1(states, new_action)
+        q2 = self.critic2(states, new_action)
+        q = torch.min(q1, q2)
+
+        reward_advantage = q - v
+
+        # Optimize actor
+        actor_loss = ((self.alpha * new_log_pi - reward_advantage)).mean()
+        self.actor_optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor_optimizer.step()
+        
+        # Compute alpha loss
+        alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean()
+        self.alpha_optimizer.zero_grad()
+        alpha_loss.backward()
+        self.alpha_optimizer.step()
+        self.alpha = self.log_alpha.exp().detach()
 
         # ----------------------- update target networks ----------------------- #
         self.soft_update(self.critic1, self.critic1_target)
         self.soft_update(self.critic2, self.critic2_target)
+        self.soft_update(self.safety_critic1, self.safety_critic1_target)
+        self.soft_update(self.safety_critic2, self.safety_critic2_target)
         
-        return actor_loss.item(), alpha_loss.item(), critic1_loss.item(), critic2_loss.item(), cql1_scaled_loss.item(), cql2_scaled_loss.item(), current_alpha, cql_alpha_loss.item(), cql_alpha.item()
+        # ----------------------- update stats ----------------------- #
+        data = {
+            "actor_loss": actor_loss.item(),
+            "alpha_loss": alpha_loss.item(),
+            "alpha": self.alpha.item(),
+            "lambda_loss": csc_lambda_loss.item(),
+            "lambda": self.csc_lambda.item(),
+            "critic1_loss": critic1_loss.item(),
+            "critic2_loss": critic2_loss.item(),
+            "cql1_scaled_loss": cql1_scaled_loss.item(),
+            "cql2_scaled_loss": cql2_scaled_loss.item(),
+            "total_c1_loss": total_c1_loss.item(),
+            "total_c2_loss": total_c2_loss.item(),
+            "cql_alpha_loss": cql_alpha_loss.item(),
+            "cql_alpha": cql_alpha.item()
+        }
+        if self.stats.total_updates % 8 == 0:  # log scalars only every 8th update
+            self.stats.log_update_tensorboard(data)
+        self.stats.total_updates += 1
+        return data
 
     def soft_update(self, local_model , target_model):
         """Soft update model parameters.
diff --git a/src/stats.py b/src/stats.py
index 8b7494dec164ecaff2614582eb49da9c343b9aef..4f263bf73e5ef7fb25774d047297b8fcbfce7bf9 100644
--- a/src/stats.py
+++ b/src/stats.py
@@ -46,6 +46,11 @@ class Statistics:
         self.test_reward_history.clear()
         self.test_cost_history.clear()
         self.test_unsafe_history.clear()
+    
+    def log_update_tensorboard(self, data:dict):
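+        """Write per-update scalars to TensorBoard under the 'update/' namespace."""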
+        for k, v in data.items():
+            name = f"update/{k}"
+            self.writer.add_scalar(name, v, self.total_updates)
 
     def log_tensorboard(self, name: str, values: np.ndarray, ticks: np.ndarray):
         if values.size == 1: