From e5774eda25ca97d15d43933844871bf2b875c9a3 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Thu, 3 Oct 2024 13:36:41 +0200
Subject: [PATCH] Added shielded_action_sampling

---
 main.py       | 2 ++
 src/policy.py | 7 ++++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index 24fc27d..a59a8b4 100644
--- a/main.py
+++ b/main.py
@@ -42,6 +42,8 @@ def cmd_args():
                         help="Batch size used for training (default: 1024)")
     parser.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
                         help="Factor used in soft update of target network (default: 0.05)")
+    parser.add_argument("--shielded_action_sampling", action="store_true", default=False,
+                        help="Sample shielded actions when performing parameter updates (default: False)")
 
     # buffer args
     parser.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
diff --git a/src/policy.py b/src/policy.py
index de73a72..722070d 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -25,6 +25,7 @@ class CSCAgent():
         self._tau = args.tau
         self._hidden_dim = args.hidden_dim
         self._lambda_lr = args.csc_lambda_lr
+        self._shielded_action_sampling = args.shielded_action_sampling
 
         num_inputs = env.observation_space.shape[-1]
         num_actions = env.action_space.shape[-1]
@@ -58,7 +59,7 @@ class CSCAgent():
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
         for _ in range(self._expectation_estimation_samples):
-            a = self._action_for_update(policy=self._policy_old, state=states, shielded=True).squeeze(0).detach()
+            a = self._action_for_update(policy=self._policy_old, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
             cost_states += self._safety_critic.forward(states, a)
         cost_states /= self._expectation_estimation_samples
         return cost_actions - cost_states
@@ -167,13 +168,13 @@ class CSCAgent():
 
     def _update_safety_critic(self, states, actions, costs, next_states):
         safety_sa_env = self._safety_critic.forward(states, actions)
-        a = self._action_for_update(policy=self._policy, state=states, shielded=True).squeeze(0).detach()
+        a = self._action_for_update(policy=self._policy, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
         safety_s_env_a_p = self._safety_critic.forward(states, a)
 
         with torch.no_grad():
             safety_next_state = torch.zeros_like(safety_sa_env)
             for _ in range(self._expectation_estimation_samples):
-                a = self._action_for_update(policy=self._policy, state=next_states, shielded=True).squeeze(0).detach()
+                a = self._action_for_update(policy=self._policy, state=next_states, shielded=self._shielded_action_sampling).squeeze(0).detach()
                 safety_next_state += self._target_safety_critic.forward(next_states, a)
             safety_next_state /= self._expectation_estimation_samples
             safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state
--
GitLab
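
Reviewer note: `action="store_true"` already implies `default=False`, so the explicit
default is redundant but harmless. More importantly, the default flips the behavior at
all three patched call sites in src/policy.py: with the flag omitted they now pass
shielded=False, whereas they previously hard-coded shielded=True; passing the flag
restores the old shielded sampling. A minimal, standalone sketch of the flag's plumbing
(not part of the patch, only the argparse behavior it relies on):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--shielded_action_sampling", action="store_true", default=False,
                        help="Sample shielded actions when performing parameter updates (default: False)")

    # Flag omitted -> False: parameter updates sample unshielded actions.
    print(parser.parse_args([]).shielded_action_sampling)
    # Flag given -> True: restores the previous, always-shielded behavior.
    print(parser.parse_args(["--shielded_action_sampling"]).shielded_action_sampling)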