From 550142d053c2b14f4b5b4ac81b0a88a0f37b0a1b Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Wed, 2 Oct 2024 11:06:47 +0200
Subject: [PATCH] Add optional sigmoid activation and shielded action
 sampling in updates

---
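Note: --sigmoid_activation routes the safety critic's final layer through a
sigmoid so its output stays in (0, 1); independently, the safety-critic
update and the expectation estimates over policy actions now sample through
the new shielded helper _action_for_update instead of raw policy samples.
A minimal sketch of enabling the flag, assuming main.py is invoked directly
as the training entry point:

    python main.py --sigmoid_activation
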
 main.py         |  2 ++
 src/networks.py |  5 +++--
 src/policy.py   | 45 +++++++++++++++++++++++++++++++++++++++++++------
 3 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/main.py b/main.py
index 995ae33..d72336e 100644
--- a/main.py
+++ b/main.py
@@ -74,6 +74,8 @@ def cmd_args():
                         help="Learn rate for the value network (default: 1e-3)")
     parser.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
                         help="Hidden dimension of the networks (default: 32)")
+    parser.add_argument("--sigmoid_activation", action="store_true", default=False,
+                        help="Apply a sigmoid activation to the safety critic's output (default: False)")
 
     # common args
     parser.add_argument("--seed", action="store", type=int, default=42, metavar="N",
diff --git a/src/networks.py b/src/networks.py
index 7d86855..fbd1072 100644
--- a/src/networks.py
+++ b/src/networks.py
@@ -36,13 +36,13 @@ class ValueNetwork(nn.Module):
 
 
 class QNetwork(nn.Module):
-    def __init__(self, num_inputs, num_actions, hidden_dim):
+    def __init__(self, num_inputs, num_actions, hidden_dim, sigmoid_activation=False):
         super().__init__()
 
-        # Q1 architecture
         self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
         self.linear3 = nn.Linear(hidden_dim, 1)
+        self.last_activation = nn.Sigmoid() if sigmoid_activation else nn.Identity()  # optionally bound the output to (0, 1)
 
         self.apply(weights_init_)
 
@@ -52,6 +52,7 @@ class QNetwork(nn.Module):
         x1 = F.relu(self.linear1(xu))
         x1 = F.relu(self.linear2(x1))
         x1 = self.linear3(x1)
+        x1 = self.last_activation(x1)
 
         return x1
 
diff --git a/src/policy.py b/src/policy.py
index 66cd711..364c16b 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -28,11 +28,11 @@ class CSCAgent():
         num_inputs = env.observation_space.shape[-1]
         num_actions = env.action_space.shape[-1]
         self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
-        self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim).to(self._device)
+        self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
 
         self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
-        self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim).to(self._device)
+        self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
         self.soft_update(self._policy_old, self._policy, tau=1)
         self.soft_update(self._target_safety_critic, self._safety_critic, tau=1)
@@ -57,7 +57,7 @@ class CSCAgent():
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
         for _ in range(self._expectation_estimation_samples):
-            a = self._policy_old.sample(states).squeeze(0).detach()
+            a = self._action_for_update(policy=self._policy_old, state=states, shielded=True).squeeze(0).detach()
             cost_states += self._safety_critic.forward(states, a)
         cost_states /= self._expectation_estimation_samples
         return cost_actions - cost_states
@@ -166,13 +166,13 @@ class CSCAgent():
     def _update_safety_critic(self, states, actions, costs, next_states):
         safety_sa_env = self._safety_critic.forward(states, actions)
 
-        a = self._policy.sample(states).squeeze(0).detach()
+        a = self._action_for_update(policy=self._policy, state=states, shielded=True).squeeze(0).detach()
         safety_s_env_a_p = self._safety_critic.forward(states, a)
 
         with torch.no_grad():
             safety_next_state = torch.zeros_like(safety_sa_env)
             for _ in range(self._expectation_estimation_samples):
-                a = self._policy.sample(next_states).squeeze(0).detach()
+                a = self._action_for_update(policy=self._policy, state=next_states, shielded=True).squeeze(0).detach()
                 safety_next_state += self._target_safety_critic.forward(next_states, a)
             safety_next_state /= self._expectation_estimation_samples
             safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state
@@ -243,4 +243,37 @@ class CSCAgent():
         
         else:
             action = self._policy.sample(state)
-            return action.squeeze(0).cpu().numpy()
\ No newline at end of file
+            return action.squeeze(0).cpu().numpy()
+
+    @torch.no_grad()
+    def _action_for_update(self, policy, state, shielded=True, return_dist=False):
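+        """Sample an action for the update steps.
+
+        With shielded=True, draw self._shield_iterations candidates from the
+        policy, score them with the safety critic, and return the first one
+        whose estimated unsafety does not exceed self._unsafety_threshold;
+        if no candidate qualifies, fall back to the least-unsafe one.
+        """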
+        if shielded:
+            dist = policy.distribution(state)
+            action = dist.sample((self._shield_iterations,))
+
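+            # action has shape (shield_iterations, batch, act_dim); expand state to match so the critic scores every candidate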
+            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1))
+            unsafety = self._safety_critic.forward(state, action).squeeze(2).permute(1, 0)   # (batch, shield_iterations)
+
+            mask = unsafety <= self._unsafety_threshold
+            argmax = mask.int().argmax(dim=1)   # column index of the first candidate deemed safe
+            is_zero = mask[torch.arange(0, state.shape[1], device=self._device), argmax] == 0    # rows in which every candidate is unsafe
+
+            action = action.permute(1, 0, 2)
+            result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device)
+            result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...]     # fall back to the least-unsafe action
+            result[~is_zero] = action[~is_zero, argmax[~is_zero], ...]    # otherwise take the first safe action
+        else:
+            dist = policy.distribution(state)
+            result = dist.sample().squeeze(0)
+
+        if return_dist:
+            return result, dist
+        return result
-- 
GitLab