From 490b0ef28f179dc9916420e5b312257743c69e4e Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Tue, 25 Feb 2025 17:07:11 +0100
Subject: [PATCH] Added faster shielded action sampling

---
 main.py              |  2 +-
 src/cql_sac/agent.py | 55 +++++++++++++++++++++++---------------------
 2 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/main.py b/main.py
index 8215e83..4766320 100644
--- a/main.py
+++ b/main.py
@@ -151,7 +151,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train
 
     while running_episodes > 0:
         # sample and execute actions
-        actions = agent.get_action(state, eval=not train).cpu().numpy()
+        actions = agent.get_action(state, shielded=shielded).cpu().numpy()
         next_state, reward, cost, terminated, truncated, info = env.step(actions)
         done = terminated | truncated
         not_done_masked = ((~done) & mask)
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
index 3e6e4cb..6f10681 100644
--- a/src/cql_sac/agent.py
+++ b/src/cql_sac/agent.py
@@ -94,41 +94,44 @@ class CSCCQLSAC(nn.Module):
         self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate)
         self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate)
 
-
-    def get_action(self, state, eval=False):
+    def get_action(self, state, shielded=True):
         """
         Returns shielded actions for given state as per current policy.
-
-        Note: eval is currently ignored.
         """
         state = torch.from_numpy(state).float().to(self.device)
-
-        batch_size = state.shape[0]
-        unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
-        unsafety_best = torch.full((batch_size, ), fill_value=unsafety_threshold+1).to(self.device)
-        action_best = torch.zeros(batch_size, self.action_size).to(self.device)
-
-        # Run at max 'csc_shield_iterations' iterations to find safe action
-        for _ in range(self.csc_shield_iterations):
-            # If all actions are already safe, break
-            mask_safe = unsafety_best <= unsafety_threshold
-            if mask_safe.all(): break
-
-            # Sample new actions
+
+        if shielded:
+            # Repeat state, resulting shape: (shield_iterations, batch_size, state_size)
+            state = state.repeat((self.csc_shield_iterations, 1)).reshape(self.csc_shield_iterations, *state.shape)
+            unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
+
+            # Sample all 'csc_shield_iterations' actions at once for every state
             with torch.no_grad():
                 action = self.actor_local.get_action(state).to(self.device)
 
-            # Estimate safety of new actions
+            # Estimate unsafety of all actions
             q1 = self.safety_critic1(state, action)
             q2 = self.safety_critic2(state, action)
-            unsafety = torch.min(q1, q2).squeeze(1)
-
-            # Update best actions if they are still unsafe and new actions are safer
-            mask_update = (~mask_safe) & (unsafety < unsafety_best)
-            unsafety_best[mask_update] = unsafety[mask_update]
-            action_best[mask_update] = action[mask_update]
-
-        return action_best
+            unsafety = torch.min(q1, q2).squeeze(2)
+
+            # Check for actions that qualify (unsafety <= threshold), locate the first (if any) for every state
+            mask_safe = unsafety <= unsafety_threshold
+            idx_first_safe = mask_safe.int().argmax(dim=0)
+
+            # If all actions of a state are unsafe (> threshold), argmax returns index 0, since that state's column of mask_safe is 0 everywhere
+            mask_all_unsafe = (~mask_safe[idx_first_safe, torch.arange(0, mask_safe.shape[1])])
+
+            # We now build an idx to access the action tensor as follows:
+            # If there was at least one safe action, idx_first_safe will be the first safe action for each state
+            # If there was no safe action, we retrieve the action with minimum unsafety for each state
+            idx_0 = idx_first_safe
+            idx_0[mask_all_unsafe] = unsafety[:, mask_all_unsafe].argmin(dim=0)
+            idx_1 = torch.arange(0, mask_safe.shape[1])
+
+            # Access action tensor and return
+            return action[idx_0, idx_1, :]
+        else:
+            return self.actor_local.get_action(state).to(self.device)
 
     def calc_policy_loss(self, states, alpha):
         actions_pred, log_pis = self.actor_local.evaluate(states)
-- 
GitLab
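
Note on the selection logic, since it carries the speedup: mask_safe.int().argmax(dim=0) finds the first safe candidate per state because argmax returns the index of the first maximal value, and an all-False column is detected afterwards by inspecting the element argmax picked. Below is a minimal standalone sketch of that trick, with dummy tensors standing in for the actor and the safety critics; K, B, A and the threshold value are illustrative placeholders, not values from the repo.

    import torch

    # K shield iterations, B states in the batch, A action dimensions
    K, B, A = 5, 3, 2
    torch.manual_seed(0)

    action = torch.randn(K, B, A)   # candidate actions: one per iteration and state
    unsafety = torch.rand(K, B)     # stand-in for torch.min(q1, q2).squeeze(2)
    unsafety_threshold = 0.3        # stand-in for (1 - gamma) * (chi - avg_unsafe)

    # Index of the first candidate per state whose unsafety is below the threshold
    mask_safe = unsafety <= unsafety_threshold
    idx_first_safe = mask_safe.int().argmax(dim=0)

    # States with no safe candidate: argmax picked index 0 of an all-False column
    mask_all_unsafe = ~mask_safe[idx_first_safe, torch.arange(B)]

    # Fall back to the least-unsafe candidate for those states
    idx_0 = idx_first_safe.clone()
    idx_0[mask_all_unsafe] = unsafety[:, mask_all_unsafe].argmin(dim=0)

    selected = action[idx_0, torch.arange(B), :]
    print(selected.shape)  # torch.Size([3, 2])

The fallback relies on argmax returning the index of the first occurrence when several entries share the maximum, which current PyTorch documents for torch.argmax; the patch depends on the same behaviour.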