From 937cac7d5b14ff1cb300fcdbd608db3777046fd7 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Sat, 8 Mar 2025 00:17:35 +0100
Subject: [PATCH] Add CSC actor update
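
Replace the plain SAC policy-gradient step with the trust-region
natural-gradient update from the CSC paper (Equation 6): the update
direction is obtained by (pseudo-)inverting an empirical Fisher
matrix, and a backtracking line search shrinks the step size until
the KL divergence between the old and new policy stays below
csc_delta, restoring the old parameters otherwise.

Also clamp the Lagrange multiplier to stay non-negative, make tau an
explicit argument of soft_update, switch to double precision to keep
the Fisher solve numerically stable, and expose the number of
shielding samples and backtracking iterations as CLI flags.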

---
 main.py              |  10 +++
 src/cql_sac/agent.py | 206 +++++++++++++++++++++++++++++++------------
 2 files changed, 160 insertions(+), 56 deletions(-)

diff --git a/main.py b/main.py
index bff7053..496acce 100644
--- a/main.py
+++ b/main.py
@@ -85,6 +85,10 @@ def cmd_args():
                         help="Set the value of alpha")
     csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
                         help="Set the initial value of lambda")
+    csc_args.add_argument("--csc_shield_iterations", action="store", type=int, default=100, metavar="N",
+                        help="Set the number of sampled actions during shielding")
+    csc_args.add_argument("--csc_update_iterations", action="store", type=int, default=20, metavar="N",
+                        help="Set the number of linear backtracking iterations")
 
     # common args
     common_args = parser.add_argument_group('Common')
@@ -96,6 +100,11 @@ def cmd_args():
                         help="Set the output and log directory path")
     common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
                         help="Set the maximum number of threads for pytorch and numpy")
+
+    # dev args
+    dev_args = parser.add_argument_group('Development')
+    dev_args.add_argument("--fisher_inversion", action="store", type=str, default="pinv", metavar="INV",
+                        choices=["pinv", "inv", "none"],
+                        help="Inversion mode for the Fisher matrix: pinv (least-squares/pseudo-inverse), inv (damped solve), none (plain gradient)")
 
     args = parser.parse_args()
     return args
@@ -109,6 +118,7 @@ def setup(args):
     Performs setup like fixing seeds, initializing env and agent, buffer and stats.
     """
     torch.set_num_threads(args.num_threads)
+    torch.set_default_dtype(torch.double)   # float64 keeps the Fisher solve numerically stable
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
index 8ed7054..827c28d 100644
--- a/src/cql_sac/agent.py
+++ b/src/cql_sac/agent.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 from torch.nn.utils import clip_grad_norm_
 from .networks import Critic, Actor
 import math
-
+import copy
+import torch.nn.functional as F   # used for the KL estimate in the actor update
 
 class CSCCQLSAC(nn.Module):
     """Interacts with and learns from the environment."""
@@ -49,10 +49,11 @@ class CSCCQLSAC(nn.Module):
         self.cql_weight = args.cql_weight
         self.cql_target_action_gap = args.cql_target_action_gap
         self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
-        self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate) 
+        self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate)
 
         # CSC params
-        self.csc_shield_iterations = 100
+        self.csc_shield_iterations = args.csc_shield_iterations
+        self.csc_update_iterations = args.csc_update_iterations
         self.csc_alpha = args.csc_alpha
         self.csc_beta = args.csc_beta
         self.csc_delta = args.csc_delta
@@ -61,6 +62,9 @@ class CSCCQLSAC(nn.Module):
 
         self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device)
         self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate)
+
+        # Dev args
+        self.fisher_inversion = args.fisher_inversion
         
         # Actor Network 
         self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device)
@@ -94,9 +98,9 @@ class CSCCQLSAC(nn.Module):
 
     def get_action(self, state, shielded=True):
         """ 
-        Returns shielded actions for given state as per current policy. 
+        Returns shielded actions for given state as per current policy.
         """
-        state = torch.from_numpy(state).float().to(self.device).unsqueeze(0)
+        state = torch.from_numpy(state).to(device=self.device, dtype=torch.get_default_dtype()).unsqueeze(0)
         
         if shielded:
             # Repeat state, resulting shape: (shield_iterations, state_size)
@@ -150,27 +154,20 @@ class CSCCQLSAC(nn.Module):
         return random_values - random_log_probs
     
     def learn(self, experiences):
-        """Updates actor, critics and entropy_alpha parameters using given batch of experience tuples.
-        Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state))
-        Critic_loss = MSE(Q, Q_target)
-        Actor_loss = α * log_pi(a|s) - Q(s,a)
-        where:
-            actor_target(state) -> action
-            critic_target(state, action) -> Q-value
+        """ Updates actor, critics, safety_critics and entropy_alpha parameters using given batch of experience tuples.
         Params
         ======
-            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples 
-            gamma (float): discount factor
+            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples
         """
         self.csc_avg_unsafe = self.stats.train_unsafe_avg
         states, actions, rewards, costs, next_states, dones = experiences
 
-        states = torch.from_numpy(states).float().to(self.device)
-        actions = torch.from_numpy(actions).float().to(self.device)
-        rewards = torch.from_numpy(rewards).float().to(self.device).view(-1, 1)
-        costs = torch.from_numpy(costs).float().to(self.device).view(-1, 1)
-        next_states = torch.from_numpy(next_states).float().to(self.device)
-        dones = torch.from_numpy(dones).float().to(self.device).view(-1, 1)
+        states = torch.from_numpy(states).to(device=self.device, dtype=torch.get_default_dtype())
+        actions = torch.from_numpy(actions).to(device=self.device, dtype=torch.get_default_dtype())
+        rewards = torch.from_numpy(rewards).to(device=self.device, dtype=torch.get_default_dtype()).view(-1, 1)
+        costs = torch.from_numpy(costs).to(device=self.device, dtype=torch.get_default_dtype()).view(-1, 1)
+        next_states = torch.from_numpy(next_states).to(device=self.device, dtype=torch.get_default_dtype())
+        dones = torch.from_numpy(dones).to(device=self.device, dtype=torch.get_default_dtype()).view(-1, 1)
 
         # ---------------------------- update critic ---------------------------- #
         # Get predicted next-state actions and Q values from target models
@@ -274,7 +271,102 @@ class CSCCQLSAC(nn.Module):
         clip_grad_norm_(self.safety_critic2.parameters(), self.clip_grad_param)
         self.safety_critic2_optimizer.step()
 
+        # ---------------------------- update actor ---------------------------- #
+        # Equation 6 of the paper
+        # Estimate reward advantage
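+        # min Q(s, a_buffer) serves as the baseline V(s), so the advantage
+        # below is estimated as min Q(s, pi(s)) - min Q(s, a_buffer)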
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
+        v = torch.min(q1, q2).detach()
+
+        new_action, new_log_pi = self.actor_local.evaluate(states)
+        q1 = self.critic1(states, new_action)
+        q2 = self.critic2(states, new_action)
+        q = torch.min(q1, q2)
+
+        reward_advantage = q - v
+
+        # Estimate cost advantage
+        q1 = self.safety_critic1(states, actions)
+        q2 = self.safety_critic2(states, actions)
+        v = torch.min(q1, q2).detach()
+
+        q1 = self.safety_critic1(states, new_action)
+        q2 = self.safety_critic2(states, new_action)
+        q = torch.min(q1, q2)
+
+        cost_advantage = (q - v).mean()
+
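+        # Calculate losses: the Lagrangian surrogate (Equation 6) maximizes
+        # E[A_R - alpha*log pi] - lambda/(1-gamma) * E[A_C]; total_loss is its
+        # negation and is minimized below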
+        reward_loss = (reward_advantage - self.alpha * new_log_pi).mean()
+        cost_loss = self.csc_lambda.detach() / (1 - self.gamma) * cost_advantage
+        total_loss = -(reward_loss - cost_loss)
+
+        # Calculate objective gradients
+        self.actor_optimizer.zero_grad()
+        total_loss.backward(retain_graph=True)
+        objective_grads = self.flatten_gradients(self.actor_local).reshape((-1, 1))
+
+        # Calculate fisher matrix
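+        # Empirical Fisher: F ~ (1/b) * sum_i g_i g_i^T with g_i = grad_theta log pi(a_i|s_i).
+        # Building the dense (n x n) matrix is O(b*n^2), only practical for small actors.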
+        n = objective_grads.numel()
+        b = new_log_pi.shape[0]
+        fisher = torch.zeros((n, n), dtype=torch.get_default_dtype(), device=self.device)
+        for i in range(b):
+            self.actor_optimizer.zero_grad()
+            new_log_pi[i].backward(retain_graph=(i + 1 < b))
+            g = self.flatten_gradients(self.actor_local).reshape((-1, 1))
+            fisher += g @ g.T
+        fisher /= b   # average the outer products over the batch
+
+        # Calculate update gradient
+        if self.fisher_inversion == "pinv":
+            # Least-squares solve (pseudo-inverse); done on the CPU since the
+            # CUDA backend of lstsq requires a full-rank matrix
+            update_grads = torch.linalg.lstsq(fisher.cpu(), objective_grads.cpu()).solution.flatten().to(self.device)
+        elif self.fisher_inversion == "inv":
+            fisher += 1e-5 * torch.eye(n, device=self.device)   # small ridge term so the system is non-singular
+            update_grads = torch.linalg.solve(fisher, objective_grads).flatten()
+        elif self.fisher_inversion == "none":
+            update_grads = objective_grads.flatten()
+        else:
+            raise ValueError(f"Unknown fisher_inversion mode: {self.fisher_inversion}")
+
+        # Maximal step length for a KL trust region of size delta:
+        # sqrt(2*delta / (g^T F^-1 g)), where g^T F^-1 g is computed from the
+        # already-solved natural-gradient direction
+        beta_term = torch.sqrt((2*self.csc_delta) / (objective_grads.flatten() @ update_grads))
+
+        # Store old parameters so the line search can roll back
+        actor_backup = copy.deepcopy(self.actor_local)
+
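+        # Backtracking line search: try progressively smaller steps along the
+        # natural gradient and keep the first one whose policy stays inside
+        # the KL trust region; otherwise roll back and shrink the step further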
+        for j in range(self.csc_update_iterations):
+            beta_j = self.csc_beta * (1.0 - self.csc_beta)**j   # geometrically shrinking step fraction
+            beta = (beta_j * beta_term).item()
+
+            # Take a natural-gradient descent step of size beta (update_grads
+            # points along the gradient of total_loss, hence the negative sign)
+            self.apply_flattened_gradient(
+                model=self.actor_local,
+                grads=update_grads,
+                lr=-beta
+            )
+
+            # Monte-Carlo estimate of the KL divergence between the updated
+            # and the old policy from freshly sampled log-probs
+            _, next_log_pi = self.actor_local.evaluate(states)
+            kl_div = F.kl_div(new_log_pi, next_log_pi, log_target=True, reduction='mean').item()    # 'mean' equals 'batchmean' here: one log-prob per sample
+
+            if kl_div <= self.csc_delta:
+                break
+            else:
+                # Restore original parameters
+                self.hard_update(
+                    local_model=actor_backup,
+                    target_model=self.actor_local,
+                )
+
+        # Update alpha (standard SAC temperature update; reuses the log-probs
+        # sampled before the natural-gradient step)
+        alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean()
+        self.alpha_optimizer.zero_grad()
+        alpha_loss.backward()
+        self.alpha_optimizer.step()
+        self.alpha = self.log_alpha.exp().detach()
+
         # ---------------------------- update csc lambda ---------------------------- #
+        # Equation 6 of the paper
         # Estimate cost advantage
         with torch.no_grad():
             q1 = self.safety_critic1(states, actions)
@@ -294,42 +386,17 @@ class CSCCQLSAC(nn.Module):
         self.csc_lambda_optimizer.zero_grad()
         csc_lambda_loss.backward()
         self.csc_lambda_optimizer.step()
-
-        # ---------------------------- update actor ---------------------------- #
-        # Estimate reward advantage
-        q1 = self.critic1(states, actions)
-        q2 = self.critic2(states, actions)
-        v = torch.min(q1, q2).detach()
-
-        new_action, new_log_pi = self.actor_local.evaluate(states)
-        q1 = self.critic1(states, new_action)
-        q2 = self.critic2(states, new_action)
-        q = torch.min(q1, q2)
-
-        reward_advantage = q - v
-
-        # Optimize actor
-        actor_loss = ((self.alpha * new_log_pi - reward_advantage)).mean()
-        self.actor_optimizer.zero_grad()
-        actor_loss.backward()
-        self.actor_optimizer.step()
-        
-        # Compute alpha loss
-        alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean()
-        self.alpha_optimizer.zero_grad()
-        alpha_loss.backward()
-        self.alpha_optimizer.step()
-        self.alpha = self.log_alpha.exp().detach()
+        self.csc_lambda.data.clamp_(min=0.0)   # the Lagrange multiplier must stay non-negative
 
         # ----------------------- update target networks ----------------------- #
-        self.soft_update(self.critic1, self.critic1_target)
-        self.soft_update(self.critic2, self.critic2_target)
-        self.soft_update(self.safety_critic1, self.safety_critic1_target)
-        self.soft_update(self.safety_critic2, self.safety_critic2_target)
+        self.soft_update(self.critic1, self.critic1_target, tau=self.tau)
+        self.soft_update(self.critic2, self.critic2_target, tau=self.tau)
+        self.soft_update(self.safety_critic1, self.safety_critic1_target, tau=self.tau)
+        self.soft_update(self.safety_critic2, self.safety_critic2_target, tau=self.tau)
         
         # ----------------------- update stats ----------------------- #
         data = {
-            "actor_loss": actor_loss.item(),
+            "actor_loss": total_loss.item(),
             "alpha_loss": alpha_loss.item(),
             "alpha": self.alpha.item(),
             "lambda_loss": csc_lambda_loss.item(),
@@ -341,14 +408,19 @@ class CSCCQLSAC(nn.Module):
             "total_c1_loss": total_c1_loss.item(),
             "total_c2_loss": total_c2_loss.item(),
             "cql_alpha_loss": cql_alpha_loss.item(),
-            "cql_alpha": cql_alpha.item()
+            "cql_alpha": cql_alpha.item(),
+            "csc_update_iterations": j,
+            "csc_kl_div": kl_div,
         }
-        self.stats.total_updates += 1
-        if (self.stats.total_updates - 1) % 8 == 0:
+        if self.stats.total_updates % 8 == 0:
             self.stats.log_update_tensorboard(data)
+        self.stats.total_updates += 1
         return data
 
-    def soft_update(self, local_model , target_model):
+    def hard_update(self, local_model, target_model):
+        """Copies the local model parameters into the target model."""
+        self.soft_update(local_model, target_model, tau=1.0)
+
+    def soft_update(self, local_model, target_model, tau=0.005):
         """Soft update model parameters.
         θ_target = τ*θ_local + (1 - τ)*θ_target
         Params
@@ -358,4 +430,26 @@ class CSCCQLSAC(nn.Module):
             tau (float): interpolation parameter 
         """
         for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
-            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)
+            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
+
+    def apply_flattened_gradient(self, model, grads, lr):
+        """Adds lr * grads to the parameters, consuming the flat vector in the
+        same order in which flatten_gradients produced it."""
+        idx = 0
+        for param in model.parameters():
+            # Skip parameters that did not receive a gradient
+            if not param.requires_grad or param.grad is None:
+                continue
+            # Slice out this parameter's segment and restore its shape
+            n = param.numel()
+            g = grads[idx:idx+n].reshape(param.shape)
+            # In-place parameter update
+            param.data.add_(lr * g)
+            idx += n
+        assert idx == grads.numel(), "flat gradient does not match model parameters"
+
+    def flatten_gradients(self, model):
+        """Returns the gradients of all trainable parameters as one flat vector."""
+        grads = []
+        for param in model.parameters():
+            if not param.requires_grad or param.grad is None:
+                continue
+            grads.append(param.grad.detach().clone().flatten())
+        return torch.cat(grads)
\ No newline at end of file
-- 
GitLab