diff --git a/main.py b/main.py index 82f260d2f5cde50fffb4f8842306e1a1ce2b3444..24fc27d762551b5ee5b225ee9055f08f4ad5f50f 100644 --- a/main.py +++ b/main.py @@ -66,12 +66,14 @@ def cmd_args(): help="Set the value of beta (default: 0.7)") parser.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N", help="Set the value of alpha (default: 0.5)") - parser.add_argument("--csc_lambda", action="store", type=float, default=4e-2, metavar="N", - help="Set the initial value of lambda (default: 4e-2)") + parser.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N", + help="Set the initial value of lambda (default: 1.0)") parser.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N", help="Learn rate for the safety critic (default: 2e-4)") parser.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N", help="Learn rate for the value network (default: 1e-3)") + parser.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N", + help="Learn rate for the lambda dual variable (default: 4e-2)") parser.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N", help="Hidden dimension of the networks (default: 32)") parser.add_argument("--sigmoid_activation", action="store_true", default=False, diff --git a/src/policy.py b/src/policy.py index 364c16b982cd6f1a1a7a70a1a51c520e1bef872c..de73a72050bb6e03382120a764fcb9d5c922f70a 100644 --- a/src/policy.py +++ b/src/policy.py @@ -24,6 +24,7 @@ class CSCAgent(): self._expectation_estimation_samples = args.expectation_estimation_samples self._tau = args.tau self._hidden_dim = args.hidden_dim + self._lambda_lr = args.csc_lambda_lr num_inputs = env.observation_space.shape[-1] num_actions = env.action_space.shape[-1] @@ -191,7 +192,7 @@ class CSCAgent(): adv = self._cost_advantage(states, actions).mean() chi_prime = self._chi - self._avg_failures gradient = (gamma_inv * adv - chi_prime).item() - self._lambda -= gradient + self._lambda -= self._lambda_lr * gradient return gradient def update(self, buffer, avg_failures, total_episodes):