From 5408092610bc4f0f0b1d48ff95a5979d07529484 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Mon, 30 Sep 2024 17:01:19 +0200
Subject: [PATCH] Added policy_old and reworked the update step

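Snapshot the current policy into _policy_old (soft_update with tau=1)
and use that snapshot wherever the update needs a fixed reference: the
action expectation in the cost-advantage estimate (now evaluated with
_safety_critic instead of the target critic), the per-sample log-prob
gradients, and dist_old in the line search. _update_value_network is
moved below _update_policy, intermediate tensors are freed with del,
_value_network is switched to eval mode, and a new after_updates() hook
re-syncs _policy_old with _policy once all updates are done.

Intended call order from the training loop, as a rough sketch only: the
driver script is not part of this patch, so the constructor arguments,
args.num_episodes and the data-collection step below are assumptions.

    agent = CSCAgent(args, env, writer)
    for episode in range(args.num_episodes):
        # collect transitions into buffer and track avg_failures here
        agent.update(buffer, avg_failures, episode)  # uses the _policy_old snapshot
        agent.after_updates()                        # refresh _policy_old from _policy
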
---
 src/policy.py | 76 +++++++++++++++++++++++++--------------------------
 1 file changed, 38 insertions(+), 38 deletions(-)

diff --git a/src/policy.py b/src/policy.py
index 2ef02cd..66cd711 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -23,16 +23,18 @@ class CSCAgent():
         self._lambda = args.csc_lambda
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
+        self._hidden_dim = args.hidden_dim
 
         num_inputs = env.observation_space.shape[-1]
         num_actions = env.action_space.shape[-1]
-        hidden_dim = args.hidden_dim
-        self._policy = GaussianPolicy(num_inputs, num_actions, hidden_dim, env.action_space).to(self._device)
-        self._safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device)
-        self._value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device)
-
-        self._target_safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device)
-        self._target_value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device)
+        self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
+        self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim).to(self._device)
+        self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
+
+        self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
+        self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim).to(self._device)
+        self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
+        self.soft_update(self._policy_old, self._policy, tau=1)
         self.soft_update(self._target_safety_critic, self._safety_critic, tau=1)
         self.soft_update(self._target_value_network, self._value_network, tau=1)
 
@@ -41,7 +43,9 @@ class CSCAgent():
 
         self._policy.eval()
         self._safety_critic.eval()
-    
+        self._value_network.eval()
+
+
     @staticmethod
     @torch.no_grad
     def soft_update(target, source, tau):
@@ -53,30 +57,18 @@ class CSCAgent():
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
         for _ in range(self._expectation_estimation_samples):
-            a = self._policy.sample(states).squeeze(0).detach()
-            cost_states += self._target_safety_critic.forward(states, a)
+            a = self._policy_old.sample(states).squeeze(0).detach()
+            cost_states += self._safety_critic.forward(states, a)
         cost_states /= self._expectation_estimation_samples
         return cost_actions - cost_states
     
-
     def _reward_diff(self, states, rewards, next_states):
         value_states = self._value_network.forward(states)
         with torch.no_grad():
             value_next_states = self._target_value_network.forward(next_states)
             value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
         return value_target - value_states
-
-
-    def _update_value_network(self, states, rewards, next_states):
-        value_diff = self._reward_diff(states, rewards, next_states)
-        loss = 1/2 * torch.square(value_diff).mean()
-
-        self._optim_value_network.zero_grad()
-        loss.backward()
-        self._optim_value_network.step()
-        return loss.item()
-
-
+
     def _update_policy(self, states, actions, rewards, next_states):
         @torch.no_grad
         def estimate_ahat(states, actions, rewards, next_states):
@@ -94,14 +86,14 @@ class CSCAgent():
                 param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
                 l += n
         
-        def log_prob_grad(states, actions):
+        def log_prob_grad_batched(states, actions):
             # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
-            params = {k: v.detach() for k, v in self._policy.named_parameters()}
-            buffers = {k: v.detach() for k, v in self._policy.named_buffers()}
+            params = {k: v.detach() for k, v in self._policy_old.named_parameters()}
+            buffers = {k: v.detach() for k, v in self._policy_old.named_buffers()}
             def compute_log_prob(params, buffers, s, a):
                 s = s.unsqueeze(0)
                 a = a.unsqueeze(0)
-                lp = functional_call(self._policy, (params, buffers), (s,a))
+                lp = functional_call(self._policy_old, (params, buffers), (s,a))
                 return lp.sum()
             
             ft_compute_grad = grad(compute_log_prob)
@@ -109,14 +101,14 @@ class CSCAgent():
             ft_per_sample_grads = ft_compute_sample_grad(params, buffers, states, actions)
 
             grads = []
-            for key, param in self._policy.named_parameters():
+            for key, param in self._policy_old.named_parameters():
                 if not param.requires_grad: continue
                 g = ft_per_sample_grads[key].flatten(start_dim=1)
                 grads.append(g)
             return torch.hstack(grads)
-
+
         @torch.no_grad
-        def fisher_matrix(glog_prob):
+        def fisher_matrix_batched(glog_prob):
             b, n = glog_prob.shape
             s,i = 32,0  # batch size, index
             fisher = torch.zeros((n,n), device=self._device)
@@ -126,11 +118,12 @@ class CSCAgent():
                 i += s
             return fisher
 
-        glog_prob = log_prob_grad(states, actions).detach()
+        glog_prob = log_prob_grad_batched(states, actions).detach()
         with torch.no_grad():
-            fisher = fisher_matrix(glog_prob)
+            fisher = fisher_matrix_batched(glog_prob)
             ahat = estimate_ahat(states, actions, rewards, next_states)
             gJ = (glog_prob * ahat).mean(dim=0).view((-1, 1))
+            del glog_prob, ahat
 
             beta_term = torch.sqrt((2 * self._delta) / (gJ.T @ fisher @ gJ))
             
@@ -138,15 +131,14 @@ class CSCAgent():
             # NOTE: requires full rank on cuda
             #gradient_term = torch.linalg.lstsq(fisher.cpu(), gJ.cpu()).solution.to(self._device)   # too slow
             
-            # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32, --hidden_dim 64 numpy is slow
+            # NOTE: pytorch is surprisingly slow at computing the pseudoinverse with --hidden_dim 32, so fall back to numpy's lstsq on the CPU
             time_start = time.time()
             gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0], device=self._device)
             time_end = time.time()
             print(f"[DEBUG] Inverse took: {round(time_end - time_start, 2)}s")
+            del fisher, gJ
 
-            del glog_prob, ahat, fisher, gJ
-
-            dist_old = self._policy.distribution(states)
+            dist_old = self._policy_old.distribution(states)
             beta_j = self._beta
             for j in range(self._line_search_iterations):
                 beta_j = beta_j * (1 - beta_j)**j
@@ -162,6 +154,14 @@ class CSCAgent():
                     apply_gradient(-gradient)
         return torch.nan
 
+    def _update_value_network(self, states, rewards, next_states):
+        value_diff = self._reward_diff(states, rewards, next_states)
+        loss = 1/2 * torch.square(value_diff).mean()
+
+        self._optim_value_network.zero_grad()
+        loss.backward()
+        self._optim_value_network.step()
+        return loss.item()
 
     def _update_safety_critic(self, states, actions, costs, next_states):
         safety_sa_env = self._safety_critic.forward(states, actions)
@@ -185,7 +185,6 @@ class CSCAgent():
         self._optim_safety_critic.step()
         return loss.item()
 
-
     @torch.no_grad
     def _update_lambda(self, states, actions):
         gamma_inv = 1 / (1 - self._gamma)
@@ -195,7 +194,6 @@ class CSCAgent():
         self._lambda -= gradient
         return gradient
 
-
     def update(self, buffer, avg_failures, total_episodes):
         self._avg_failures = avg_failures
         self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
@@ -220,6 +218,8 @@ class CSCAgent():
         self._writer.add_scalar(f"agent/safety_loss", round(sloss,4), total_episodes)
         self._writer.add_scalar(f"agent/lambda_gradient", round(lgradient,4), total_episodes)
 
+    def after_updates(self):
+        self.soft_update(self._policy_old, self._policy, tau=1)
 
     @torch.no_grad
     def sample(self, state, shielded=True):
-- 
GitLab