From 937cac7d5b14ff1cb300fcdbd608db3777046fd7 Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Sat, 8 Mar 2025 00:17:35 +0100 Subject: [PATCH] Added csc actor update --- main.py | 10 +++ src/cql_sac/agent.py | 206 +++++++++++++++++++++++++++++++------------ 2 files changed, 160 insertions(+), 56 deletions(-) diff --git a/main.py b/main.py index bff7053..496acce 100644 --- a/main.py +++ b/main.py @@ -85,6 +85,10 @@ def cmd_args(): help="Set the value of alpha") csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N", help="Set the initial value of lambda") + csc_args.add_argument("--csc_shield_iterations", action="store", type=int, default=100, metavar="N", + help="Set the number of sampled actions during shielding") + csc_args.add_argument("--csc_update_iterations", action="store", type=int, default=20, metavar="N", + help="Set the number of linear backtracking iterations") # common args common_args = parser.add_argument_group('Common') @@ -96,6 +100,11 @@ def cmd_args(): help="Set the output and log directory path") common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N", help="Set the maximum number of threads for pytorch and numpy") + + # dev args + dev_args = parser.add_argument_group('Development') + dev_args.add_argument("--fisher_inversion", action="store", type=str, default="pinv", metavar="INV", + help="Inversion mode for fisher matrix: pinv, inv, none") args = parser.parse_args() return args @@ -109,6 +118,7 @@ def setup(args): Performs setup like fixing seeds, initializing env and agent, buffer and stats. """ torch.set_num_threads(args.num_threads) + torch.set_default_dtype(torch.double) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py index 8ed7054..827c28d 100644 --- a/src/cql_sac/agent.py +++ b/src/cql_sac/agent.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch.nn.utils import clip_grad_norm_ from .networks import Critic, Actor import math - +import copy class CSCCQLSAC(nn.Module): """Interacts with and learns from the environment.""" @@ -49,10 +49,11 @@ class CSCCQLSAC(nn.Module): self.cql_weight = args.cql_weight self.cql_target_action_gap = args.cql_target_action_gap self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.device) - self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate) + self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate) # CSC params - self.csc_shield_iterations = 100 + self.csc_shield_iterations = args.csc_shield_iterations + self.csc_update_iterations = args.csc_update_iterations self.csc_alpha = args.csc_alpha self.csc_beta = args.csc_beta self.csc_delta = args.csc_delta @@ -61,6 +62,9 @@ class CSCCQLSAC(nn.Module): self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device) self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) + + # Dev args + self.fisher_inversion = args.fisher_inversion # Actor Network self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device) @@ -94,9 +98,9 @@ class CSCCQLSAC(nn.Module): def get_action(self, state, shielded=True): """ - Returns shielded actions for given state as per current policy. + Returns shielded actions for given state as per current policy. 
""" - state = torch.from_numpy(state).float().to(self.device).unsqueeze(0) + state = torch.from_numpy(state).to(torch.get_default_dtype()).to(self.device).unsqueeze(0) if shielded: # Repeat state, resulting shape: (shield_iterations, state_size) @@ -150,27 +154,20 @@ class CSCCQLSAC(nn.Module): return random_values - random_log_probs def learn(self, experiences): - """Updates actor, critics and entropy_alpha parameters using given batch of experience tuples. - Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state)) - Critic_loss = MSE(Q, Q_target) - Actor_loss = α * log_pi(a|s) - Q(s,a) - where: - actor_target(state) -> action - critic_target(state, action) -> Q-value + """ Updates actor, critics, safety_critics and entropy_alpha parameters using given batch of experience tuples. Params ====== - experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples - gamma (float): discount factor + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples """ self.csc_avg_unsafe = self.stats.train_unsafe_avg states, actions, rewards, costs, next_states, dones = experiences - states = torch.from_numpy(states).float().to(self.device) - actions = torch.from_numpy(actions).float().to(self.device) - rewards = torch.from_numpy(rewards).float().to(self.device).view(-1, 1) - costs = torch.from_numpy(costs).float().to(self.device).view(-1, 1) - next_states = torch.from_numpy(next_states).float().to(self.device) - dones = torch.from_numpy(dones).float().to(self.device).view(-1, 1) + states = torch.from_numpy(states).to(torch.get_default_dtype()).to(self.device) + actions = torch.from_numpy(actions).to(torch.get_default_dtype()).to(self.device) + rewards = torch.from_numpy(rewards).to(torch.get_default_dtype()).to(self.device).view(-1, 1) + costs = torch.from_numpy(costs).to(torch.get_default_dtype()).to(self.device).view(-1, 1) + next_states = torch.from_numpy(next_states).to(torch.get_default_dtype()).to(self.device) + dones = torch.from_numpy(dones).to(torch.get_default_dtype()).to(self.device).view(-1, 1) # ---------------------------- update critic ---------------------------- # # Get predicted next-state actions and Q values from target models @@ -274,7 +271,102 @@ class CSCCQLSAC(nn.Module): clip_grad_norm_(self.safety_critic2.parameters(), self.clip_grad_param) self.safety_critic2_optimizer.step() + # ---------------------------- update actor ---------------------------- # + # Equation 6 of the paper + # Estimate reward advantage + q1 = self.critic1(states, actions) + q2 = self.critic2(states, actions) + v = torch.min(q1, q2).detach() + + new_action, new_log_pi = self.actor_local.evaluate(states) + q1 = self.critic1(states, new_action) + q2 = self.critic2(states, new_action) + q = torch.min(q1, q2) + + reward_advantage = q - v + + # Estimate cost advantage + q1 = self.safety_critic1(states, actions) + q2 = self.safety_critic2(states, actions) + v = torch.min(q1, q2).detach() + + q1 = self.safety_critic1(states, new_action) + q2 = self.safety_critic2(states, new_action) + q = torch.min(q1, q2) + + cost_advantage = (q - v).mean() + + # Calculate losses + reward_loss = ((reward_advantage - self.alpha * new_log_pi)).mean() + cost_loss = self.csc_lambda.detach() / (1 - self.gamma) * cost_advantage + total_loss = -(reward_loss - cost_loss) + + # Calculate objective gradients + self.actor_optimizer.zero_grad() + total_loss.backward(retain_graph=True) + objective_grads = self.flatten_gradients(self.actor_local).reshape((-1, 
1)) + + # Calculate fisher matrix + n = objective_grads.numel() + b = new_log_pi.shape[0] + fisher = torch.zeros((n,n), dtype=torch.get_default_dtype()).to(self.device) + for i in range(b): + self.actor_optimizer.zero_grad() + new_log_pi[i].backward(retain_graph=bool(i+1<b)) + g = self.flatten_gradients(self.actor_local).reshape((-1, 1)) + fisher += g @ g.T + fisher /= n + + # Calculate update gradient + if self.fisher_inversion == "pinv": + update_grads = torch.linalg.lstsq(fisher.cpu(), objective_grads.cpu()).solution.flatten().to(self.device) + elif self.fisher_inversion == "inv": + fisher += 1e-5 * torch.eye(n).to(self.device) # add some small constant to make it non singular + update_grads = torch.linalg.solve(fisher, objective_grads).flatten() + elif self.fisher_inversion == "none": + update_grads = objective_grads.flatten() + + # Calculate beta terms + beta_term = torch.sqrt((2*self.csc_delta) / (objective_grads.T @ fisher @ objective_grads)) + beta_j = self.csc_beta + + # Store old parameters + actor_backup = copy.deepcopy(self.actor_local) + actor_backup.load_state_dict(self.actor_local.state_dict()) + + for j in range(self.csc_update_iterations): + beta_j *= (1 - beta_j)**j + beta = (beta_j * beta_term).item() + + # Apply gradient with beta as learning rate + self.apply_flattened_gradient( + model=self.actor_local, + grads=update_grads, + lr=beta + ) + + # Calculate log_probs and kl_div for new actions + _, next_log_pi = self.actor_local.evaluate(states) + kl_div = F.kl_div(new_log_pi, next_log_pi, log_target=True, reduction='mean').item() # in our case same as batchmean + + if kl_div <= self.csc_delta: + break + else: + # Restore original parameters + self.hard_update( + local_model=actor_backup, + target_model=self.actor_local, + ) + + # Update alpha + alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + self.alpha = self.log_alpha.exp().detach() + # ---------------------------- update csc lambda ---------------------------- # + # Equation 6 of the paper # Estimate cost advantage with torch.no_grad(): q1 = self.safety_critic1(states, actions) @@ -294,42 +386,17 @@ class CSCCQLSAC(nn.Module): self.csc_lambda_optimizer.zero_grad() csc_lambda_loss.backward() self.csc_lambda_optimizer.step() - - # ---------------------------- update actor ---------------------------- # - # Estimate reward advantage - q1 = self.critic1(states, actions) - q2 = self.critic2(states, actions) - v = torch.min(q1, q2).detach() - - new_action, new_log_pi = self.actor_local.evaluate(states) - q1 = self.critic1(states, new_action) - q2 = self.critic2(states, new_action) - q = torch.min(q1, q2) - - reward_advantage = q - v - - # Optimize actor - actor_loss = ((self.alpha * new_log_pi - reward_advantage)).mean() - self.actor_optimizer.zero_grad() - actor_loss.backward() - self.actor_optimizer.step() - - # Compute alpha loss - alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean() - self.alpha_optimizer.zero_grad() - alpha_loss.backward() - self.alpha_optimizer.step() - self.alpha = self.log_alpha.exp().detach() + self.csc_lambda.data.clamp_(min=0.0) # ----------------------- update target networks ----------------------- # - self.soft_update(self.critic1, self.critic1_target) - self.soft_update(self.critic2, self.critic2_target) - self.soft_update(self.safety_critic1, self.safety_critic1_target) - self.soft_update(self.safety_critic2, 
self.safety_critic2_target) + self.soft_update(self.critic1, self.critic1_target, tau=self.tau) + self.soft_update(self.critic2, self.critic2_target, tau=self.tau) + self.soft_update(self.safety_critic1, self.safety_critic1_target, tau=self.tau) + self.soft_update(self.safety_critic2, self.safety_critic2_target, tau=self.tau) # ----------------------- update stats ----------------------- # data = { - "actor_loss": actor_loss.item(), + "actor_loss": total_loss.item(), "alpha_loss": alpha_loss.item(), "alpha": self.alpha.item(), "lambda_loss": csc_lambda_loss.item(), @@ -341,14 +408,19 @@ class CSCCQLSAC(nn.Module): "total_c1_loss": total_c1_loss.item(), "total_c2_loss": total_c2_loss.item(), "cql_alpha_loss": cql_alpha_loss.item(), - "cql_alpha": cql_alpha.item() + "cql_alpha": cql_alpha.item(), + "csc_update_iterations": j, + "csc_kl_div": kl_div, } - self.stats.total_updates += 1 - if (self.stats.total_updates - 1) % 8 == 0: + if self.stats.total_updates % 8 == 0: self.stats.log_update_tensorboard(data) + self.stats.total_updates += 1 return data - def soft_update(self, local_model , target_model): + def hard_update(self, local_model , target_model): + self.soft_update(local_model, target_model, tau=1.0) + + def soft_update(self, local_model , target_model, tau=0.005): """Soft update model parameters. θ_target = τ*θ_local + (1 - τ)*θ_target Params @@ -358,4 +430,26 @@ class CSCCQLSAC(nn.Module): tau (float): interpolation parameter """ for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): - target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data) + target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data) + + def apply_flattened_gradient(self, model, grads, lr): + idx = 0 + for name, param in model.named_parameters(): + # Check if this param requires updating the gradient + if not param.requires_grad: continue + if param.grad is None: continue + # Extract gradient and reshape + n = param.numel() + g = grads[idx:idx+n].reshape(param.shape) + # Update parameter and idx + param.data.copy_(param.data + lr*g) + idx += n + assert(idx == grads.numel()) + + def flatten_gradients(self, model): + grads = [] + for name, param in model.named_parameters(): + if not param.requires_grad: continue + if param.grad is None: continue + grads.append(param.grad.detach().clone().flatten()) + return torch.cat(grads) \ No newline at end of file -- GitLab
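For context, the actor update introduced above follows the constrained natural-gradient pattern: flatten the objective gradient, build an empirical Fisher matrix from per-sample log-probability gradients, solve for the natural-gradient direction, scale the step by sqrt(2*delta / g^T F^-1 g), and backtrack until a sampled KL estimate stays within the trust-region radius delta. Below is a minimal, self-contained sketch of that pattern on a toy linear-Gaussian policy. It is not the repository's Actor/Critic code: the policy, dimensions, delta, and the halving schedule in the line search are illustrative assumptions.

import torch

torch.manual_seed(0)

# Toy linear-Gaussian policy: mean = state @ W.T, fixed std. (Illustrative stand-in
# for the Actor network; dimensions and constants below are assumptions.)
state_dim, action_dim, batch = 3, 2, 16
W = torch.randn(action_dim, state_dim, requires_grad=True)
log_std = torch.zeros(action_dim)

def log_prob(states, actions, weight):
    mean = states @ weight.T
    dist = torch.distributions.Normal(mean, log_std.exp())
    return dist.log_prob(actions).sum(dim=-1)            # shape: (batch,)

states = torch.randn(batch, state_dim)
actions = torch.randn(batch, action_dim)
advantages = torch.randn(batch)                          # stand-in for the advantage estimate

# Gradient g of the surrogate objective E[A * log pi(a|s)] w.r.t. the policy parameters.
surrogate = (advantages * log_prob(states, actions, W)).mean()
g = torch.autograd.grad(surrogate, W)[0].reshape(-1, 1)

# Empirical Fisher matrix: average outer product of per-sample score gradients.
n = W.numel()
fisher = torch.zeros(n, n)
for i in range(batch):
    lp_i = log_prob(states[i:i + 1], actions[i:i + 1], W).squeeze()
    gi = torch.autograd.grad(lp_i, W)[0].reshape(-1, 1)
    fisher += gi @ gi.T
fisher /= batch                                          # average over the batch

# Natural-gradient direction via a least-squares solve (pseudo-inverse behaviour,
# analogous to the patch's default --fisher_inversion pinv mode).
nat_grad = torch.linalg.lstsq(fisher, g).solution        # shape: (n, 1)

# Largest step that keeps the quadratic KL approximation inside the trust region:
# beta_full = sqrt(2 * delta / (g^T F^+ g)).
delta = 0.01                                             # trust-region radius (assumed)
beta_full = torch.sqrt(2 * delta / (g.T @ nat_grad).clamp_min(1e-12)).item()

def sampled_kl(new_weight):
    # Monte-Carlo estimate of KL(pi_old || pi_new) using actions drawn from the old policy.
    with torch.no_grad():
        a = torch.distributions.Normal(states @ W.T, log_std.exp()).sample()
        return (log_prob(states, a, W) - log_prob(states, a, new_weight)).mean().item()

# Backtracking line search: shrink the step until the sampled KL fits the trust region.
# (A plain halving schedule here; the patch uses a (1 - beta)**j decay instead.)
step = nat_grad.reshape(W.shape)
new_W = None
for j in range(10):
    beta = beta_full * 0.5 ** j
    candidate = (W + beta * step).detach()
    if sampled_kl(candidate) <= delta:
        new_W = candidate
        break

print("accepted" if new_W is not None else "rejected", "full step size:", round(beta_full, 4))

A least-squares solve is shown instead of an explicit inverse because the empirical Fisher built from a single batch is often singular or ill-conditioned; the patch's inv mode adds a small ridge (1e-5 * I) before torch.linalg.solve for the same reason, and its none mode falls back to the plain objective gradient.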