From 145bcce5de7d99182ed2f09d64bb1709b88fbb46 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Thu, 3 Oct 2024 13:24:47 +0200
Subject: [PATCH] Added lambda lr

---
 main.py       | 6 ++++--
 src/policy.py | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index 995ae33..908863c 100644
--- a/main.py
+++ b/main.py
@@ -66,12 +66,14 @@ def cmd_args():
                         help="Set the value of beta (default: 0.7)")
     parser.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
                         help="Set the value of alpha (default: 0.5)")
-    parser.add_argument("--csc_lambda", action="store", type=float, default=4e-2, metavar="N",
-                        help="Set the initial value of lambda (default: 4e-2)")
+    parser.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
+                        help="Set the initial value of lambda (default: 1.0)")
     parser.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
                         help="Learn rate for the safety critic (default: 2e-4)")
     parser.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
                         help="Learn rate for the value network (default: 1e-3)")
+    parser.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N",
+                        help="Learn rate for the lambda dual variable (default: 4e-2)")
     parser.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
                         help="Hidden dimension of the networks (default: 32)")
 
diff --git a/src/policy.py b/src/policy.py
index 66cd711..4f825f4 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -24,6 +24,7 @@ class CSCAgent():
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
         self._hidden_dim = args.hidden_dim
+        self._lambda_lr = args.csc_lambda_lr
 
         num_inputs = env.observation_space.shape[-1]
         num_actions = env.action_space.shape[-1]
@@ -191,7 +192,7 @@ class CSCAgent():
         adv = self._cost_advantage(states, actions).mean()
         chi_prime = self._chi - self._avg_failures
         gradient = (gamma_inv * adv - chi_prime).item()
-        self._lambda -= gradient
+        self._lambda -= self._lambda_lr * gradient
         return gradient
 
     def update(self, buffer, avg_failures, total_episodes):
-- 
GitLab