diff --git a/src/policy.py b/src/policy.py
index 80e07a6371c437e1ac11418a879608273f38f3e4..005d25fbd58577dee2d85f513b17ee63c90f4c61 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -1,8 +1,5 @@
 import torch
-import numpy as np
-import time
 from torch.distributions import kl_divergence
-from torch.func import functional_call, vmap, grad
 from .networks import GaussianPolicy, QNetwork, ValueNetwork
 
 class CSCAgent():
@@ -20,7 +17,6 @@ class CSCAgent():
         self._chi = args.csc_chi
         self._avg_failures = self._chi
         self._batch_size = args.batch_size
-        self._lambda = args.csc_lambda
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
         self._hidden_dim = args.hidden_dim
@@ -32,6 +28,7 @@ class CSCAgent():
         self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
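+        # Lagrange multiplier kept as a learnable tensor so it can receive its
+        # gradient from the same objective as the policy (see _primal_dual_gradient).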
+        self._lambda = torch.tensor(args.csc_lambda, requires_grad=True, device=self._device)
 
         self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
@@ -53,8 +50,29 @@ class CSCAgent():
     def soft_update(target, source, tau):
         for tparam, sparam in zip(target.parameters(), source.parameters()):
             tparam.copy_((1 - tau) * tparam.data + tau * sparam.data)
-
+
+    @staticmethod
     @torch.no_grad
+    def apply_gradient(network: torch.nn.Module, gradient):
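+        # Add a flat gradient vector to the network's trainable parameters in-place,
+        # consuming the vector slice by slice in named_parameters() order.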
+        offset = 0
+        for _, param in network.named_parameters():
+            if not param.requires_grad: continue
+            n = param.numel()
+            param.copy_(param.data + gradient[offset:offset + n].reshape(param.shape))
+            offset += n
+
+    @staticmethod
+    @torch.no_grad
+    def collect_gradient(network: torch.nn.Module):
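+        # Flatten the .grad of every trainable parameter into a single vector, in the
+        # same named_parameters() order that apply_gradient consumes; parameters
+        # without a gradient contribute zeros.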
+        grads = []
+        for _, param in network.named_parameters():
+            if not param.requires_grad: continue
+            g = torch.zeros_like(param)
+            if param.grad is not None: g += param.grad
+            grads.append(g.flatten())
+        return torch.hstack(grads)
+
+
     def _cost_advantage(self, states, actions):
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
@@ -70,91 +88,7 @@ class CSCAgent():
             value_next_states = self._target_value_network.forward(next_states)
             value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
         return value_target - value_states
-    
-    def _update_policy(self, states, actions, rewards, next_states):
-        @torch.no_grad
-        def estimate_ahat(states, actions, rewards, next_states):
-            lambda_prime = self._lambda / (1 - self._gamma)
-            adv_cost = self._cost_advantage(states, actions)
-            reward_diff = self._reward_diff(states, rewards, next_states)
-            return reward_diff - lambda_prime * adv_cost
-        
-        @torch.no_grad
-        def apply_gradient(gradient):
-            l = 0
-            for name, param in self._policy.named_parameters():
-                if not param.requires_grad: continue
-                n = param.numel()
-                param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
-                l += n
-        
-        def log_prob_grad_batched(states, actions):
-            # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
-            params = {k: v.detach() for k, v in self._policy_old.named_parameters()}
-            buffers = {k: v.detach() for k, v in self._policy_old.named_buffers()}
-            def compute_log_prob(params, buffers, s, a):
-                s = s.unsqueeze(0)
-                a = a.unsqueeze(0)
-                lp = functional_call(self._policy_old, (params, buffers), (s,a))
-                return lp.sum()
-            
-            ft_compute_grad = grad(compute_log_prob)
-            ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
-            ft_per_sample_grads = ft_compute_sample_grad(params, buffers, states, actions)
-
-            grads = []
-            for key, param in self._policy_old.named_parameters():
-                if not param.requires_grad: continue
-                g = ft_per_sample_grads[key].flatten(start_dim=1)
-                grads.append(g)
-            return torch.hstack(grads)
-        
-        @torch.no_grad
-        def fisher_matrix_batched(glog_prob):
-            b, n = glog_prob.shape
-            s,i = 32,0  # batch size, index
-            fisher = torch.zeros((n,n), device=self._device)
-            while i < b:
-                g = glog_prob[i:i+s, ...]
-                fisher += (vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b
-                i += s
-            return fisher
-
-        glog_prob = log_prob_grad_batched(states, actions).detach()
-        with torch.no_grad():
-            fisher = fisher_matrix_batched(glog_prob)
-            ahat = estimate_ahat(states, actions, rewards, next_states)
-            gJ = (glog_prob * ahat).mean(dim=0).view((-1, 1))
-            del glog_prob, ahat
-
-            beta_term = torch.sqrt((2 * self._delta) / (gJ.T @ fisher @ gJ))
-            
-            # https://pytorch.org/docs/stable/generated/torch.linalg.pinv.html
-            # NOTE: requires full rank on cuda
-            #gradient_term = torch.linalg.lstsq(fisher.cpu(), gJ.cpu()).solution.to(self._device)   # too slow
-            
-            # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32
-            time_start = time.time()
-            gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0], device=self._device)
-            time_end = time.time()
-            print(f"[DEBUG] Inverse took: {round(time_end - time_start, 2)}s")
-            del fisher, gJ
-
-            dist_old = self._policy_old.distribution(states)
-            beta_j = self._beta
-            for j in range(self._line_search_iterations):
-                beta_j = beta_j * (1 - beta_j)**j
-                gradient = beta_j * beta_term * gradient_term
-
-                apply_gradient(gradient)
-                dist_new = self._policy.distribution(states)
-                kl_div = kl_divergence(dist_old, dist_new).mean()
-                if kl_div <= self._delta:
-                    return j
-                else:
-                    # NOTE: probably better to save the weights and restore them
-                    apply_gradient(-gradient)
-        return torch.nan
+
 
     def _update_value_network(self, states, rewards, next_states):
         value_diff = self._reward_diff(states, rewards, next_states)
@@ -187,14 +121,53 @@ class CSCAgent():
         self._optim_safety_critic.step()
         return loss.item()
 
-    @torch.no_grad
-    def _update_lambda(self, states, actions):
-        gamma_inv = 1 / (1 - self._gamma)
-        adv = self._cost_advantage(states, actions).mean()
-        chi_prime = self._chi - self._avg_failures
-        gradient = (gamma_inv * adv - chi_prime).item()
-        self._lambda -= self._lambda_lr * gradient
-        return gradient
+    def _primal_dual_gradient(self, states, actions, rewards, next_states, total_episodes):
+        # Equation 45
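+        # Importance-sampled surrogate of the Lagrangian:
+        #   E_{pi_old}[ (pi / pi_old) * (A_reward - lambda/(1-gamma) * A_cost) ] + lambda * (chi - avg_failures)
+        # A single backward pass through it yields both the policy (primal) and lambda (dual) gradients.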
+        reward_advantage = self._reward_diff(states, rewards, next_states).detach()
+        lambda_prime = self._lambda / (1 - self._gamma)
+        chi_prime_term = self._lambda * (self._chi - self._avg_failures)
+        cost_advantage = self._cost_advantage(states, actions).detach()
+        log_prob_new = self._policy.log_prob(states, actions)
+        log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        ratio = (log_prob_new - log_prob_old).exp()
+
+        objective = (ratio * (reward_advantage - lambda_prime * cost_advantage)).mean() + chi_prime_term
+
+        self._policy.zero_grad()
+        self._lambda.grad = None
+        objective.backward()
+
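+        # Dual step on lambda: follow its gradient of the objective with step size self._lambda_lr.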
+        update_lambda = self._lambda_lr * self._lambda.grad
+        with torch.no_grad():
+            self._lambda.copy_(self._lambda.data + update_lambda)  # descent
+
+        self._writer.add_scalar(f"debug/objective", round(objective.item(),4), total_episodes)
+        self._writer.add_scalar(f"debug/lambda", round(self._lambda.item(),4), total_episodes)
+
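+        # Probe the full, unscaled gradient step to measure its KL divergence from the
+        # old policy; sqrt(delta / KL) then rescales the step so the first line-search
+        # trial lands near the trust-region bound (KL grows roughly quadratically in step size).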
+        dist_old = self._policy_old.distribution(states)
+        gradient = self.collect_gradient(self._policy)
+        self.apply_gradient(self._policy, gradient)     # ascent
+        dist_new = self._policy.distribution(states)
+        self.apply_gradient(self._policy, -gradient)    # descent back
+        kl_div = kl_divergence(dist_old, dist_new).mean()
+
+        beta_term = torch.sqrt(self._delta / kl_div)
+        beta_j = self._beta
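+        # Backtracking line search: try progressively smaller steps and keep the first
+        # one whose new policy stays within the KL trust region.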
+        for j in range(self._line_search_iterations):
+            beta_j = beta_j * (1 - beta_j)**j
+            beta = beta_j * beta_term
+
+            update_policy = beta * gradient
+            self.apply_gradient(self._policy, update_policy)      # ascent
+            dist_new = self._policy.distribution(states)
+            kl_div = kl_divergence(dist_old, dist_new).mean()
+
+            if kl_div <= self._delta:
+                return j, update_lambda.item()
+            else:
+                self.apply_gradient(self._policy, -update_policy)     # descent back
+
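+        # No trial step satisfied the KL constraint; signal this with NaN and keep the old policy.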
+        return torch.nan, update_lambda.item()
+
 
     def update(self, buffer, avg_failures, total_episodes):
         self._avg_failures = avg_failures
@@ -207,10 +180,9 @@ class CSCAgent():
         costs = torch.tensor(costs, device=self._device)
         next_states = torch.tensor(next_states, device=self._device)
 
-        piter = self._update_policy(states, actions, rewards, next_states)
         vloss = self._update_value_network(states, rewards, next_states)
         sloss = self._update_safety_critic(states, actions, costs, next_states)
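+        # The primal-dual step runs after the critic updates, so it sees the freshly
+        # updated value network and safety critic.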
-        lgradient = self._update_lambda(states, actions)
+        piter, lgradient = self._primal_dual_gradient(states, actions, rewards, next_states, total_episodes)
 
         self.soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
         self.soft_update(self._target_value_network, self._value_network, tau=self._tau)