diff --git a/main.py b/main.py
index 24fc27d762551b5ee5b225ee9055f08f4ad5f50f..a59a8b464e9c6cd64c543dc8adb2a8bdcf41d0cc 100644
--- a/main.py
+++ b/main.py
@@ -42,6 +42,8 @@ def cmd_args():
                         help="Batch size used for training (default: 1024)")
     parser.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
                         help="Factor used in soft update of target network (default: 0.05)")
+    parser.add_argument("--shielded_action_sampling", action="store_true", default=False,
+                        help="Sample shielded actions when performing parameter updates (default: False)")
 
     # buffer args
     parser.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
diff --git a/src/policy.py b/src/policy.py
index de73a72050bb6e03382120a764fcb9d5c922f70a..722070dc3e6a21cfada441a2a33e1f8244bbc1f0 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -25,6 +25,7 @@ class CSCAgent():
         self._tau = args.tau
         self._hidden_dim = args.hidden_dim
         self._lambda_lr = args.csc_lambda_lr
+        self._shielded_action_sampling = args.shielded_action_sampling
 
         num_inputs = env.observation_space.shape[-1]
         num_actions = env.action_space.shape[-1]
@@ -58,7 +59,7 @@ class CSCAgent():
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
         for _ in range(self._expectation_estimation_samples):
-            a = self._action_for_update(policy=self._policy_old, state=states, shielded=True).squeeze(0).detach()
+            a = self._action_for_update(policy=self._policy_old, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
             cost_states += self._safety_critic.forward(states, a)
         cost_states /= self._expectation_estimation_samples
         return cost_actions - cost_states
@@ -167,13 +168,13 @@ class CSCAgent():
     def _update_safety_critic(self, states, actions, costs, next_states):
         safety_sa_env = self._safety_critic.forward(states, actions)
 
-        a = self._action_for_update(policy=self._policy, state=states, shielded=True).squeeze(0).detach()
+        a = self._action_for_update(policy=self._policy, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
         safety_s_env_a_p = self._safety_critic.forward(states, a)
 
         with torch.no_grad():
             safety_next_state = torch.zeros_like(safety_sa_env)
             for _ in range(self._expectation_estimation_samples):
-                a = self._action_for_update(policy=self._policy, state=next_states, shielded=True).squeeze(0).detach()
+                a = self._action_for_update(policy=self._policy, state=next_states, shielded=self._shielded_action_sampling).squeeze(0).detach()
                 safety_next_state += self._target_safety_critic.forward(next_states, a)
             safety_next_state /= self._expectation_estimation_samples
             safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state