From c2058916bb87617314ed0c01223abd320b8bab0c Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Thu, 20 Feb 2025 14:39:57 +0100
Subject: [PATCH] Temporary helper, policy, networks commit

---
 src/helper.py   |  27 ++++
 src/networks.py |  34 +++--
 src/policy.py   | 350 ++++++++++++++++++++++++++++--------------------
 3 files changed, 255 insertions(+), 156 deletions(-)
 create mode 100644 src/helper.py

diff --git a/src/helper.py b/src/helper.py
new file mode 100644
index 0000000..05ea49a
--- /dev/null
+++ b/src/helper.py
@@ -0,0 +1,27 @@
+import torch
+from torch.nn import Module
+
+@torch.no_grad
+def soft_update(target:Module, source:Module, tau:float):
+    """
+    Performs a soft parameter update of the target network.
+    """
+    for tparam, sparam in zip(target.parameters(), source.parameters()):
+        tparam.data.copy_(tau * sparam.data + (1.0 - tau) * tparam.data)
+
+@torch.no_grad
+def hard_update(target:Module, source:Module):
+    """
+    Performs a hard parameter update of the target network. Copies the source parameter's data into target.
+    """
+    soft_update(target, source, tau=1.0)
+
+@torch.no_grad
+def apply_gradient(network:Module, lr):
+    """
+    Uses model gradients and a learning rate to update all model parameters.
+    """
+    for param in network.parameters():
+        if not param.requires_grad or param.grad is None:
+            continue
+        param += lr * param.grad
\ No newline at end of file
diff --git a/src/networks.py b/src/networks.py
index 0ba8102..5d73c34 100644
--- a/src/networks.py
+++ b/src/networks.py
@@ -3,12 +3,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import Normal
 
-"""
-Copied from spice project.
-"""
-
 LOG_SIG_MAX = 2
 LOG_SIG_MIN = -20
+epsilon = 1e-6
 
 # Initialize Policy weights
 def weights_init_(m):
@@ -85,14 +82,35 @@ class GaussianPolicy(nn.Module):
         return mean, log_std
     
     def distribution(self, state):
-        mean, log_std = self.forward(state)
-        std = log_std.exp()
-        normal = Normal(mean, std)
+        mean, log_std = self.forward(state)
+        std = log_std.exp()
+        try:
+            normal = Normal(mean, std)
+        except ValueError:
+            # Invalid distribution parameters (e.g. NaN mean/std): dump diagnostics and
+            # re-raise instead of exit(0), so the failure is visible to the caller.
+            print("state:", state)
+            print("mean:", mean)
+            print("log_std:", log_std)
+            print("std:", std)
+            print("linear1:", self.linear1)
+            print("linear2:", self.linear2)
+            print("mean_linear:", self.mean_linear)
+            print("log_std_linear:", self.log_std_linear)
+            print("action_bias:", self.action_bias)
+            print("action_scale:", self.action_scale)
+            raise
+
         return normal
     
     def log_prob(self, state, action):
         dist = self.distribution(state)
-        return dist.log_prob(action)
+        # invert the tanh squashing: action = action_scale * tanh(x_t) + action_bias
+        y_t = (action - self.action_bias) / self.action_scale
+        # clamp away from +/-1 so atanh stays finite at the action bounds
+        y_t = y_t.clamp(-1 + epsilon, 1 - epsilon)
+        x_t = torch.atanh(y_t)
+        log_prob = dist.log_prob(x_t)
+        # change-of-variables correction for the tanh transform
+        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
+        log_prob = log_prob.sum(dim=-1, keepdim=True)
+        return log_prob
 
     def sample(self, state, num_samples=1):
         normal = self.distribution(state)
diff --git a/src/policy.py b/src/policy.py
index 005d25f..ae73543 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -1,10 +1,13 @@
 import torch
 from torch.distributions import kl_divergence
-from .networks import GaussianPolicy, QNetwork, ValueNetwork
+from torch.nn.functional import mse_loss
+from src.networks import GaussianPolicy, QNetwork, ValueNetwork
+from src.helper import soft_update, hard_update, apply_gradient
 
 class CSCAgent():
-    def __init__(self, env, args, writer) -> None:
-        self._writer = writer
+    def __init__(self, env, args, buffer, stats) -> None:
+        self._buffer = buffer
+        self._stats = stats
 
         self._device = args.device
         self._shield_iterations = args.shield_iterations
@@ -15,7 +18,7 @@ class CSCAgent():
         self._gamma = args.csc_gamma
         self._delta = args.csc_delta
         self._chi = args.csc_chi
-        self._avg_failures = self._chi
+        self._avg_train_unsafe = self._chi
         self._batch_size = args.batch_size
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
@@ -28,219 +31,270 @@ class CSCAgent():
         self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
-        self._lambda = torch.tensor(args.csc_lambda, requires_grad=True, device=self._device)
+        self._lambda = torch.nn.Parameter(torch.tensor(args.csc_lambda, requires_grad=True, device=self._device))
 
         self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
-        self.soft_update(self._policy_old, self._policy, tau=1)
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=1)
-        self.soft_update(self._target_value_network, self._value_network, tau=1)
+        hard_update(self._policy_old, self._policy)
+        hard_update(self._target_safety_critic, self._safety_critic)
+        hard_update(self._target_value_network, self._value_network)
 
         self._optim_safety_critic = torch.optim.Adam(self._safety_critic.parameters(), lr=args.csc_safety_critic_lr)
         self._optim_value_network = torch.optim.Adam(self._value_network.parameters(), lr=args.csc_value_network_lr)
 
-        self._policy.eval()
-        self._safety_critic.eval()
-        self._value_network.eval()
-
-
-    @staticmethod
-    @torch.no_grad
-    def soft_update(target, source, tau):
-        for tparam, sparam in zip(target.parameters(), source.parameters()):
-            tparam.copy_((1 - tau) * tparam.data + tau * sparam.data)
-    
-    @staticmethod
-    @torch.no_grad
-    def apply_gradient(network:torch.nn.Module, gradient):
-        l = 0
-        for name, param in network.named_parameters():
-            if not param.requires_grad: continue
-            n = param.numel()
-            param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
-            l += n
-    
-    @staticmethod
-    @torch.no_grad
-    def collect_gradient(network:torch.nn.Module):
-        grads = []
-        for name, param in network.named_parameters():
-            if not param.requires_grad: continue
-            g = torch.zeros_like(param)
-            if not param.grad is None: g += param.grad
-            grads.append(g.flatten())
-        return torch.hstack(grads)
-
-
-    def _cost_advantage(self, states, actions):
-        cost_actions = self._safety_critic.forward(states, actions)
-        cost_states = torch.zeros_like(cost_actions)
-        for _ in range(self._expectation_estimation_samples):
-            a = self._action_for_update(policy=self._policy_old, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
-            cost_states += self._safety_critic.forward(states, a)
-        cost_states /= self._expectation_estimation_samples
-        return cost_actions - cost_states
-    
-    def _reward_diff(self, states, rewards, next_states):
+        # TODO: give the policy optimizer its own learning-rate argument instead of reusing csc_safety_critic_lr
+        self._optim_policy = torch.optim.Adam(self._policy.parameters(), lr=args.csc_safety_critic_lr)
+
+    def _update_value_network(self, states, rewards, next_states, dones):
+        """
+        Updates the reward value network using MSE on a batch of experiences.
+        """
         value_states = self._value_network.forward(states)
         with torch.no_grad():
             value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
             value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
-        return value_target - value_states
-
-
-    def _update_value_network(self, states, rewards, next_states):
-        value_diff = self._reward_diff(states, rewards, next_states)
-        loss = 1/2 * torch.square(value_diff).mean()
+        loss = mse_loss(value_states, value_target, reduction='mean')
 
         self._optim_value_network.zero_grad()
         loss.backward()
         self._optim_value_network.step()
         return loss.item()
 
-    def _update_safety_critic(self, states, actions, costs, next_states):
+
+    def _update_safety_critic(self, states, actions, costs, next_states, dones):
+        """
+        Updates the safety critic network using the CQL objective (Equation 2) from the CSC paper.
+        """
+        # states and actions as stored in the buffer (collected under the old policy)
         safety_sa_env = self._safety_critic.forward(states, actions)
 
-        a = self._action_for_update(policy=self._policy, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
-        safety_s_env_a_p = self._safety_critic.forward(states, a)
+        # states from old policy (from buffer), actions from current policy
+        with torch.no_grad():
+            actions_current_policy = self.sample(states, shielded=self._shielded_action_sampling)
+        safety_s_env_a_p = self._safety_critic.forward(states, actions_current_policy)
+
+        # first CQL term: raises the predicted cost of on-policy actions and lowers it for
+        # buffer actions, steering the critic towards a conservative overapproximation
+        loss1 = (-safety_s_env_a_p.mean() + safety_sa_env.mean())
 
+        # bellman operator
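+        # target: c(s,a) + gamma * E_{a'~pi}[Q_C_target(s',a')], with the bootstrap masked at episode termination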
         with torch.no_grad():
             safety_next_state = torch.zeros_like(safety_sa_env)
             for _ in range(self._expectation_estimation_samples):
-                a = self._action_for_update(policy=self._policy, state=next_states, shielded=self._shielded_action_sampling).squeeze(0).detach()
-                safety_next_state += self._target_safety_critic.forward(next_states, a)
+                # actions must be sampled for the *next* states to form the Bellman target
+                actions_current_policy = self.sample(next_states, shielded=self._shielded_action_sampling)
+                safety_next_state += self._target_safety_critic.forward(next_states, actions_current_policy)
             safety_next_state /= self._expectation_estimation_samples
-            safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state
-        safety_sasc_env = torch.square(safety_sa_env - safety_next_state)
+            safety_next_state *= (1 - dones.view((-1, 1)))
+            safety_target = costs.view((-1, 1)) + self._gamma * safety_next_state
+        
+        # second loss term (mse loss)
+        loss2 = mse_loss(safety_sa_env, safety_target, reduction='mean')
 
-        loss = self._alpha * (safety_sa_env.mean() - safety_s_env_a_p.mean()) + 1/2 * safety_sasc_env.mean()
+        # overall weighted loss as sum of loss1 and loss2
+        loss = self._alpha * loss1 + 1/2 * loss2
 
         self._optim_safety_critic.zero_grad()
         loss.backward()
         self._optim_safety_critic.step()
         return loss.item()
 
-    def _primal_dual_gradient(self, states, actions, rewards, next_states, total_episodes):
-        # Equation 45
-        reward_advantage = self._reward_diff(states, rewards, next_states).detach()
+    def _tmp(self, states, actions, rewards, next_states, dones):
+        """
+        Temporary, simplified policy update: a plain importance-weighted advantage step
+        without the constraint term, trust-region line search or lambda update.
+        """
+        # importance sampling weights
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        # shift both log-probs by their shared maximum (the shift cancels in the ratio below)
+        max_log_prob = torch.max(torch.stack((action_log_prob, action_log_prob_old)))
+        action_log_prob -= max_log_prob
+        action_log_prob_old -= max_log_prob
+        # calculate importance ratio
+        weighting:torch.Tensor = (action_log_prob - action_log_prob_old).exp()
+
+        # reward advantage
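+        # A_R(s,a) ~ r + gamma * V_target(s') - V_target(s) (one-step TD estimate)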
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+        objective = (weighting * reward_advantage).mean()
+
+        # manual gradient ascent on the surrogate objective (reuses self._lambda_lr as the
+        # step size for now; alternative: minimize -objective with self._optim_policy)
+        self._policy.zero_grad()
+        objective.backward()
+        apply_gradient(self._policy, lr=self._lambda_lr)
+
+        return objective.item()
+
+
+    def _primal_dual_gradient(self, states, actions, rewards, next_states, dones):
+        """
+        Updates the policy and the lambda parameter using the objective from Equation 45 of the CSC paper.
+        """
+        # temporary: delegate to the simplified update in _tmp; the full primal-dual step below is currently bypassed
+        return self._tmp(states, actions, rewards, next_states, dones)
+        # importance sampling factor
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        weighting = (action_log_prob - action_log_prob_old).exp()
+
+        # reward advantage
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+
+        # lambda prime
         lambda_prime = self._lambda / (1 - self._gamma)
-        chi_prime_term = self._lambda * (self._chi - self._avg_failures)
-        cost_advantage = self._cost_advantage(states, actions).detach()
-        a,b = self._policy.log_prob(states, actions), self._policy_old.log_prob(states, actions).detach()
-        ratio = (a - b).exp()
 
-        objective = (ratio * (reward_advantage - lambda_prime*cost_advantage)).mean() + chi_prime_term
+        # cost advantage
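+        # A_C(s,a) ~ Q_C_target(s,a) - E_{a'~pi_old}[Q_C_target(s,a')]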
+        with torch.no_grad():
+            cost_actions = self._target_safety_critic.forward(states, actions)
+            cost_states = torch.zeros_like(cost_actions)
+            for _ in range(self._expectation_estimation_samples):
+                actions_policy_old = self._sample_policy_old(states, shielded=self._shielded_action_sampling)
+                cost_states += self._target_safety_critic.forward(states, actions_policy_old)
+            cost_states /= self._expectation_estimation_samples
+            cost_advantage = cost_actions - cost_states
+        
+        # chi prime term (eq. 44)
+        chi_prime_term = self._lambda * (self._chi - self._avg_train_unsafe)
+
+        # overall objective
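+        # J = E[ (pi/pi_old)(a|s) * (A_R(s,a) - lambda' * A_C(s,a)) ] + lambda * (chi - avg_train_unsafe)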
+        objective = (weighting * (reward_advantage - lambda_prime*cost_advantage)).mean() + chi_prime_term
 
+        # reset gradients
         self._policy.zero_grad()
         self._lambda.grad = None
+
+        # calculate gradients
         objective.backward()
 
-        update_lambda = self._lambda_lr * self._lambda.grad
-        with torch.no_grad():
-            self._lambda.copy_(self._lambda.data + update_lambda)  # descent
+        # apply policy gradient
+        apply_gradient(self._policy, lr=1)                  # gradient ascent
 
-        self._writer.add_scalar(f"debug/objective", round(objective.item(),4), total_episodes)
-        self._writer.add_scalar(f"debug/lambda", round(self._lambda.item(),4), total_episodes)
+        # obtain action distributions
+        action_dist_old = self._policy_old.distribution(states)
+        action_dist_new = self._policy.distribution(states)
+        kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
 
-        dist_old = self._policy_old.distribution(states)
-        gradient = self.collect_gradient(self._policy)
-        self.apply_gradient(self._policy, gradient)     # ascent
-        dist_new = self._policy.distribution(states)
-        self.apply_gradient(self._policy, -gradient)    # descent back
-        kl_div = kl_divergence(dist_old, dist_new).mean()
+        # revert update
+        apply_gradient(self._policy, lr=-1)                 # gradient descent back
 
+        # calculate beta
         beta_term = torch.sqrt(self._delta / kl_div)
         beta_j = self._beta
+
+        # line search
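+        # backtracking: shrink the step size until KL(pi_old || pi_new) <= delta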
         for j in range(self._line_search_iterations):
             beta_j = beta_j * (1 - beta_j)**j
             beta = beta_j * beta_term
 
-            update_policy = beta * gradient
-            self.apply_gradient(self._policy, update_policy)      # ascent
-            dist_new = self._policy.distribution(states)
-            kl_div = kl_divergence(dist_old, dist_new).mean()
+            apply_gradient(self._policy, lr=beta)         # gradient ascent
+            action_dist_new = self._policy.distribution(states)
+            kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
 
-            if kl_div <= self._delta:
-                return j, update_lambda.item()
+            if kl_div > self._delta:
+                apply_gradient(self._policy, lr=-beta)     # gradient descent back
             else:
-                self.apply_gradient(self._policy, -update_policy)     # descent back
-
-        return torch.nan, update_lambda.item()
+                break
 
+        # update lambda
+        with torch.no_grad():
+            self._lambda -= self._lambda_lr * self._lambda.grad   # gradient descent
+            self._lambda.clamp_min_(min=0)
+        
+        return j
 
-    def update(self, buffer, avg_failures, total_episodes):
-        self._avg_failures = avg_failures
-        self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
 
-        states, actions, rewards, costs, next_states = buffer.sample(self._batch_size)
+    def update(self):
+        """
+        Performs one iteration of updates. Updates the value network, policy network, lambda parameter and safety critic.
+        """
+        self._avg_train_unsafe = self._stats.avg_train_unsafe
+        
+        states, actions, rewards, costs, next_states, dones = self._buffer.sample(self._batch_size)
         states = torch.tensor(states, device=self._device)
         actions = torch.tensor(actions, device=self._device)
         rewards = torch.tensor(rewards, device=self._device)
         costs = torch.tensor(costs, device=self._device)
         next_states = torch.tensor(next_states, device=self._device)
+        dones = torch.tensor(dones, device=self._device)
 
-        vloss = self._update_value_network(states, rewards, next_states)
-        sloss = self._update_safety_critic(states, actions, costs, next_states)
-        piter, lgradient = self._primal_dual_gradient(states, actions, rewards, next_states, total_episodes)
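+        # one gradient step per component; target networks track their sources via soft (Polyak) updates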
+        vloss = self._update_value_network(states, rewards, next_states, dones)
+        soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        piter = self._primal_dual_gradient(states, actions, rewards, next_states, dones)
+        sloss = self._update_safety_critic(states, actions, costs, next_states, dones)
+        soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
 
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
-        self.soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        self._stats.writer.add_scalar("debug/vloss", vloss, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/sloss", sloss, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/piter", piter, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/lambda", self._lambda, self._stats.total_train_episodes)
 
-        self._writer.add_scalar(f"agent/policy_iterations", piter, total_episodes)
-        self._writer.add_scalar(f"agent/value_loss", round(vloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/safety_loss", round(sloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/lambda_gradient", round(lgradient,4), total_episodes)
 
     def after_updates(self):
-        self.soft_update(self._policy_old, self._policy, tau=1)
+        self._unsafety_threshold = (1 - self._gamma) * (self._chi - self._avg_train_unsafe)
+        hard_update(self._policy_old, self._policy)
 
-    @torch.no_grad
-    def sample(self, state, shielded=True):
-        state = torch.tensor(state, device=self._device)
-        if shielded:
-            action = self._policy.sample(state, num_samples=self._shield_iterations)
-
-            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1))
-            unsafety = self._safety_critic.forward(state, action).squeeze(2).permute(1,0)
-
-            mask = unsafety <= self._unsafety_threshold
-            argmax = mask.int().argmax(dim=1)   # col idx of first "potentially" safe action
-            is_zero = mask[torch.arange(0, state.shape[1]), argmax] == 0    # check if in a row every col is unsafe
 
-            action = action.permute(1, 0, 2)
-            result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device)
-            result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...]     # minimum in case only unsafe actions
-            result[~is_zero] = action[~is_zero, argmax[~is_zero], ...]    # else first safe action
+    def _sample_policy_old(self, state, shielded):
+        """
+        Allows action sampling from policy_old.
+        """
+        tmp = self._policy
+        self._policy = self._policy_old
+        actions = self.sample(state, shielded=shielded)
+        self._policy = tmp
+        return actions
 
-            return result.cpu().numpy()
-        
-        else:
-            action = self._policy.sample(state)
-            return action.squeeze(0).cpu().numpy()
-
-    @torch.no_grad
-    def _action_for_update(self, policy, state, shielded=True, return_dist=False):
+    def sample(self, state, shielded=True):
+        """
+        Samples and returns one action for every state. If shielded is true, performs rejection sampling according to the CSC paper using the safety critic.
+        Instead of using a loop, we sample all actions simultaneously and pick:
+        - the first with an estimated unsafety level <= self._unsafety_threshold (epsilon in the CSC paper)
+        - or else, the one that achieves maximum safety, i.e., lowest estimated unsafety
+        """
+        # as_tensor avoids an extra copy (and a warning) when state is already a tensor on the device
+        state = torch.as_tensor(state, device=self._device, dtype=torch.float64)
         if shielded:
-            dist = policy.distribution(state)
-            action = dist.sample((self._shield_iterations,))
+            # sample all actions, expand/copy state, estimate safety for all actions at the same time
+            actions = self._policy.sample(state, num_samples=self._shield_iterations) # shape: (shield_iterations, batch_size, action_size)
+            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1)) # shape: (shield_iterations, batch_size, state_size)
+            unsafety = self._safety_critic.forward(state, actions) # shape: (shield_iterations, batch_size, 1)
+            unsafety = unsafety.squeeze(2) # shape: (shield_iterations, batch_size)
 
-            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1))
-            unsafety = self._safety_critic.forward(state, action).squeeze(2).permute(1,0)
+            # check for actions that qualify (unsafety <= threshold), locate first (if exists) for every state
+            mask = unsafety <= self._unsafety_threshold # shape: (shield_iterations, batch_size)
+            row_idx = mask.int().argmax(dim=0) # idx of first "potentially" safe actions
+            batch_idx = torch.arange(0, mask.shape[1])
 
-            mask = unsafety <= self._unsafety_threshold
-            argmax = mask.int().argmax(dim=1)   # col idx of first "potentially" safe action
-            is_zero = mask[torch.arange(0, state.shape[1]), argmax] == 0    # check if in a row every col is unsafe
+            # retrieve estimated safety of said action and check if it is indeed <= threshold
+            # if for a state all actions are unsafe (> threshold), argmax will return the first unsafe as mask[state] == 0 everywhere
+            unsafety_chosen = unsafety[row_idx, batch_idx]
+            all_unsafe = unsafety_chosen > self._unsafety_threshold
 
-            action = action.permute(1, 0, 2)
-            result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device)
-            result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...]     # minimum in case only unsafe actions
-            result[~is_zero] = action[~is_zero, argmax[~is_zero], ...]    # else first safe action
+            # for such states retrieve the action with a minimum level of unsafety
+            row_idx[all_unsafe] = unsafety[:, all_unsafe].argmin(dim=0)
+
+            # gather the chosen action for every state
+            result = actions[row_idx, batch_idx, :]
+            return result
 
         else:
-            dist = policy.distribution(state)
-            result = dist.sample().squeeze(0)
-        
-        if return_dist:
-            return result, dist
-        return result
\ No newline at end of file
+            # simply sample for every state one action
+            actions = self._policy.sample(state) # shape: (num_samples, batch_size, action_size)
+            return actions.squeeze(0)
\ No newline at end of file
-- 
GitLab