diff --git a/main.py b/main.py
index 8215e83da8493c2b95087fe51331d78bf901b928..4766320bf70fffd9b9ca931db9e82f4b4cab987d 100644
--- a/main.py
+++ b/main.py
@@ -151,7 +151,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train
     
     while running_episodes > 0:
         # sample and execute actions
-        actions = agent.get_action(state, eval=not train).cpu().numpy()
+        actions = agent.get_action(state, shielded=shielded).cpu().numpy()
         next_state, reward, cost, terminated, truncated, info = env.step(actions)
         done = terminated | truncated
         not_done_masked = ((~done) & mask)
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
index 3e6e4cb0cef61d21173db14de1c7c44b4b89f447..6f10681932d821e82505abbafb11d6896a4d88d4 100644
--- a/src/cql_sac/agent.py
+++ b/src/cql_sac/agent.py
@@ -94,41 +94,52 @@ class CSCCQLSAC(nn.Module):
         self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate)
         self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate) 
 
-    
-    def get_action(self, state, eval=False):
+    def get_action(self, state, shielded=True):
         """ 
-        Returns shielded actions for given state as per current policy. 
-        
-        Note: eval is currently ignored.
+        Returns actions for the given states as per the current policy.
+
+        If 'shielded' is True, 'csc_shield_iterations' actions are sampled per state and the
+        first one with estimated unsafety below the threshold is returned (or the least unsafe one).
         """
         state = torch.from_numpy(state).float().to(self.device)
-
-        batch_size = state.shape[0]
-        unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
-        unsafety_best = torch.full((batch_size, ), fill_value=unsafety_threshold+1).to(self.device)
-        action_best = torch.zeros(batch_size, self.action_size).to(self.device)
-
-        # Run at max 'csc_shield_iterations' iterations to find safe action
-        for _ in range(self.csc_shield_iterations):
-            # If all actions are already safe, break
-            mask_safe = unsafety_best <= unsafety_threshold
-            if mask_safe.all(): break
-
-            # Sample new actions
+
+        if shielded:
+            # Repeat state, resulting shape: (shield_iterations, batch_size, state_size)
+            state = state.repeat((self.csc_shield_iterations, 1)).reshape(self.csc_shield_iterations, *state.shape)
+            unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
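+            # Sampled actions with estimated unsafety at or below this threshold count as safe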
+
+            # Sample all 'csc_shield_iterations' actions at once for every state
             with torch.no_grad():
                 action = self.actor_local.get_action(state).to(self.device)
             
-            # Estimate safety of new actions
+            # Estimate unsafety of all actions
             q1 = self.safety_critic1(state, action)
             q2 = self.safety_critic2(state, action)
-            unsafety = torch.min(q1, q2).squeeze(1)
-
-            # Update best actions if they are still unsafe and new actions are safer
-            mask_update = (~mask_safe) & (unsafety < unsafety_best)
-            unsafety_best[mask_update] = unsafety[mask_update]
-            action_best[mask_update] = action[mask_update]
-
-        return action_best
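+            # Take the elementwise minimum of the two safety critics; 'unsafety' has shape (shield_iterations, batch_size)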
+            unsafety = torch.min(q1, q2).squeeze(2)
+
+            # Mark actions that qualify (unsafety <= threshold) and locate, for every state, the first safe sample (if any)
+            mask_safe = unsafety <= unsafety_threshold
+            idx_first_safe = mask_safe.int().argmax(dim=0)
+
+            # If no sampled action is safe for a state, its column of mask_safe is all False and
+            # argmax just returns some (unsafe) index, so check whether the selected sample is actually safe
+            mask_all_unsafe = ~mask_safe[idx_first_safe, torch.arange(0, mask_safe.shape[1])]
+
+            # Build an index into the action tensor:
+            # if a state has at least one safe action, take the first safe sample;
+            # otherwise fall back to the sample with the minimum estimated unsafety
+            idx_0 = idx_first_safe
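+            # Note: idx_0 aliases idx_first_safe; the in-place update below is fine as idx_first_safe is not used afterwards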
+            idx_0[mask_all_unsafe] = unsafety[:, mask_all_unsafe].argmin(dim=0)
+            idx_1 = torch.arange(0, mask_safe.shape[1])
+
+            # Select one action per state and return it; resulting shape: (batch_size, action_size)
+            return action[idx_0, idx_1, :]
+        else:
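+            # Unshielded: return actions sampled directly from the current policy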
+            return self.actor_local.get_action(state).to(self.device)
 
     def calc_policy_loss(self, states, alpha):
         actions_pred, log_pis = self.actor_local.evaluate(states)