From c2058916bb87617314ed0c01223abd320b8bab0c Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Thu, 20 Feb 2025 14:39:57 +0100
Subject: [PATCH] Temporary helper, policy, networks commit

---
 src/helper.py   |  27 ++++
 src/networks.py |  34 +++--
 src/policy.py   | 350 ++++++++++++++++++++++++++++--------------------
 3 files changed, 255 insertions(+), 156 deletions(-)
 create mode 100644 src/helper.py

diff --git a/src/helper.py b/src/helper.py
new file mode 100644
index 0000000..05ea49a
--- /dev/null
+++ b/src/helper.py
@@ -0,0 +1,27 @@
+import torch
+from torch.nn import Module
+
+@torch.no_grad
+def soft_update(target:Module, source:Module, tau:float):
+    """
+    Performs a soft parameter update of the target network.
+    """
+    for tparam, sparam in zip(target.parameters(), source.parameters()):
+        tparam.data.copy_(tau * sparam.data + (1.0 - tau) * tparam.data)
+
+@torch.no_grad
+def hard_update(target:Module, source:Module):
+    """
+    Performs a hard parameter update of the target network. Copies the source parameter's data into target.
+    """
+    soft_update(target, source, tau=1.0)
+
+@torch.no_grad
+def apply_gradient(network:Module, lr):
+    """
+    Uses model gradients and a learning rate to update all model parameters.
+    """
+    for name, param in network.named_parameters():
+        if not param.requires_grad: continue
+        if param.grad is None: continue
+        param += lr * param.grad
\ No newline at end of file
diff --git a/src/networks.py b/src/networks.py
index 0ba8102..5d73c34 100644
--- a/src/networks.py
+++ b/src/networks.py
@@ -3,12 +3,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import Normal
 
-"""
-Copied from spice project.
-"""
-
 LOG_SIG_MAX = 2
 LOG_SIG_MIN = -20
+epsilon = 1e-6
 
 # Initialize Policy weights
 def weights_init_(m):
@@ -85,14 +82,35 @@ class GaussianPolicy(nn.Module):
         return mean, log_std
 
     def distribution(self, state):
-        mean, log_std = self.forward(state)
-        std = log_std.exp()
-        normal = Normal(mean, std)
+        try:
+            mean, log_std = self.forward(state)
+            std = log_std.exp()
+            normal = Normal(mean, std)
+        except ValueError:
+            print("state:", state)
+            print(mean)
+            print(log_std)
+            print(std)
+
+            print("linear1:", self.linear1)
+            print(self.linear2)
+            print(self.mean_linear)
+            print(self.log_std_linear)
+
+            print("bias:", self.action_bias)
+            print(self.action_scale)
+            exit(0)
+
         return normal
 
     def log_prob(self, state, action):
         dist = self.distribution(state)
-        return dist.log_prob(action)
+        y_t = (action - self.action_bias) / self.action_scale
+        x_t = torch.atanh(y_t)
+        log_prob = dist.log_prob(x_t)
+        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
+        log_prob = log_prob.sum(axis=-1, keepdim=True)
+        return log_prob
 
     def sample(self, state, num_samples=1):
         normal = self.distribution(state)
diff --git a/src/policy.py b/src/policy.py
index 005d25f..ae73543 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -1,10 +1,13 @@
 import torch
 from torch.distributions import kl_divergence
-from .networks import GaussianPolicy, QNetwork, ValueNetwork
+from torch.nn.functional import mse_loss
+from src.networks import GaussianPolicy, QNetwork, ValueNetwork
+from src.helper import soft_update, hard_update, apply_gradient
 
 class CSCAgent():
-    def __init__(self, env, args, writer) -> None:
-        self._writer = writer
+    def __init__(self, env, args, buffer, stats) -> None:
+        self._buffer = buffer
+        self._stats = stats
         self._device = args.device
 
         self._shield_iterations = args.shield_iterations
@@ -15,7 +18,7 @@ class CSCAgent():
         self._gamma = args.csc_gamma
         self._delta = args.csc_delta
         self._chi = args.csc_chi
-        self._avg_failures = self._chi
+        self._avg_train_unsafe = self._chi
         self._batch_size = args.batch_size
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
@@ -28,219 +31,270 @@ class CSCAgent():
         self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
-        self._lambda = torch.tensor(args.csc_lambda, requires_grad=True, device=self._device)
+        self._lambda = torch.nn.Parameter(torch.tensor(args.csc_lambda, requires_grad=True, device=self._device))
 
         self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
-        self.soft_update(self._policy_old, self._policy, tau=1)
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=1)
-        self.soft_update(self._target_value_network, self._value_network, tau=1)
+        hard_update(self._policy_old, self._policy)
+        hard_update(self._target_safety_critic, self._safety_critic)
+        hard_update(self._target_value_network, self._value_network)
 
         self._optim_safety_critic = torch.optim.Adam(self._safety_critic.parameters(), lr=args.csc_safety_critic_lr)
         self._optim_value_network = torch.optim.Adam(self._value_network.parameters(), lr=args.csc_value_network_lr)
 
-        self._policy.eval()
-        self._safety_critic.eval()
-        self._value_network.eval()
-
-
-    @staticmethod
-    @torch.no_grad
-    def soft_update(target, source, tau):
-        for tparam, sparam in zip(target.parameters(), source.parameters()):
-            tparam.copy_((1 - tau) * tparam.data + tau * sparam.data)
-
-    @staticmethod
-    @torch.no_grad
-    def apply_gradient(network:torch.nn.Module, gradient):
-        l = 0
-        for name, param in network.named_parameters():
-            if not param.requires_grad: continue
-            n = param.numel()
-            param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
-            l += n
-
-    @staticmethod
-    @torch.no_grad
-    def collect_gradient(network:torch.nn.Module):
-        grads = []
-        for name, param in network.named_parameters():
-            if not param.requires_grad: continue
-            g = torch.zeros_like(param)
-            if not param.grad is None: g += param.grad
-            grads.append(g.flatten())
-        return torch.hstack(grads)
-
-
-    def _cost_advantage(self, states, actions):
-        cost_actions = self._safety_critic.forward(states, actions)
-        cost_states = torch.zeros_like(cost_actions)
-        for _ in range(self._expectation_estimation_samples):
-            a = self._action_for_update(policy=self._policy_old, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
-            cost_states += self._safety_critic.forward(states, a)
-        cost_states /= self._expectation_estimation_samples
-        return cost_actions - cost_states
-
-    def _reward_diff(self, states, rewards, next_states):
+        # TODO
+        self._optim_policy = torch.optim.Adam(self._policy.parameters(), lr=args.csc_safety_critic_lr)
+
+    def _update_value_network(self, states, rewards, next_states, dones):
+        """
+        Updates the reward value network using MSE on a batch of experiences.
+        """
         value_states = self._value_network.forward(states)
         with torch.no_grad():
             value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
             value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
-        return value_target - value_states
-
-
-    def _update_value_network(self, states, rewards, next_states):
-        value_diff = self._reward_diff(states, rewards, next_states)
-        loss = 1/2 * torch.square(value_diff).mean()
+        loss = mse_loss(value_states, value_target, reduction='mean')
 
         self._optim_value_network.zero_grad()
         loss.backward()
         self._optim_value_network.step()
 
         return loss.item()
 
-    def _update_safety_critic(self, states, actions, costs, next_states):
+
+    def _update_safety_critic(self, states, actions, costs, next_states, dones):
+        """
+        Updates the safety critic network using the CQL objective (Equation 2) from the CSC paper.
+        """
+        # states, action from old policy (from buffer)
         safety_sa_env = self._safety_critic.forward(states, actions)
-        a = self._action_for_update(policy=self._policy, state=states, shielded=self._shielded_action_sampling).squeeze(0).detach()
-        safety_s_env_a_p = self._safety_critic.forward(states, a)
+        # states from old policy (from buffer), actions from current policy
+        with torch.no_grad():
+            actions_current_policy = self.sample(states, shielded=self._shielded_action_sampling)
+        safety_s_env_a_p = self._safety_critic.forward(states, actions_current_policy)
+
+        # first loss term (steers critic towards overapproximation)
+        loss1 = (-safety_s_env_a_p.mean() + safety_sa_env.mean())
+
+        # bellman operator
         with torch.no_grad():
             safety_next_state = torch.zeros_like(safety_sa_env)
             for _ in range(self._expectation_estimation_samples):
-                a = self._action_for_update(policy=self._policy, state=next_states, shielded=self._shielded_action_sampling).squeeze(0).detach()
-                safety_next_state += self._target_safety_critic.forward(next_states, a)
+                actions_current_policy = self.sample(states, shielded=self._shielded_action_sampling)
+                safety_next_state += self._target_safety_critic.forward(next_states, actions_current_policy)
             safety_next_state /= self._expectation_estimation_samples
-            safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state
-        safety_sasc_env = torch.square(safety_sa_env - safety_next_state)
+            safety_next_state *= (1 - dones.view((-1, 1)))
+            safety_target = costs.view((-1, 1)) + self._gamma * safety_next_state
+
+        # second loss term (mse loss)
+        loss2 = mse_loss(safety_sa_env, safety_target, reduction='mean')
 
-        loss = self._alpha * (safety_sa_env.mean() - safety_s_env_a_p.mean()) + 1/2 * safety_sasc_env.mean()
+        # overall weighted loss as sum of loss1 and loss2
+        loss = self._alpha * loss1 + 1/2 * loss2
 
         self._optim_safety_critic.zero_grad()
         loss.backward()
         self._optim_safety_critic.step()
 
         return loss.item()
 
-    def _primal_dual_gradient(self, states, actions, rewards, next_states, total_episodes):
-        # Equation 45
-        reward_advantage = self._reward_diff(states, rewards, next_states).detach()
+    def _tmp(self, states, actions, rewards, next_states, dones):
+        # importance sampling weights
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        # avoid underflow by subtracting the max
+        max_log_prob = torch.max(torch.stack((action_log_prob, action_log_prob_old)))
+        action_log_prob -= max_log_prob
+        action_log_prob_old -= max_log_prob
+        # calculate importance ratio
+        weighting:torch.Tensor = (action_log_prob - action_log_prob_old).exp()
+
+        # reward advantage
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+        objective = (weighting * reward_advantage).mean()
+        # objective = (weighting * value_states).mean()
+
+        #objective *= -1
+        #self._optim_policy.zero_grad()
+        #objective.backward()
+        self._policy.zero_grad()
+        objective.backward()
+        if False:
+            print("alp:", action_log_prob)
+            print(action_log_prob_old)
+            print("weighting:", weighting)
+            print(objective)
+            print("lin1grad:", self._policy.linear1.weight.grad)
+            print(self._policy.linear2.weight.grad)
+            print("meanlingrad:", self._policy.mean_linear.weight.grad)
+            print(self._policy.log_std_linear.weight.grad)
+        apply_gradient(self._policy, lr=self._lambda_lr)
+        #self._optim_policy.step()
+
+        return objective.item()
+
+
+    def _primal_dual_gradient(self, states, actions, rewards, next_states, dones):
+        """
+        Updates the policy and lambda parameter. Calculates the objective from equation 45 from the CSC paper.
+        """
+        return self._tmp(states, actions, rewards, next_states, dones)
+        # importance sampling factor
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        weighting = (action_log_prob - action_log_prob_old).exp()
+
+        # reward advantage
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+
+        # lambda prime
         lambda_prime = self._lambda / (1 - self._gamma)
-        chi_prime_term = self._lambda * (self._chi - self._avg_failures)
-        cost_advantage = self._cost_advantage(states, actions).detach()
-        a,b = self._policy.log_prob(states, actions), self._policy_old.log_prob(states, actions).detach()
-        ratio = (a - b).exp()
-        objective = (ratio * (reward_advantage - lambda_prime*cost_advantage)).mean() + chi_prime_term
+        # cost advantage
+        with torch.no_grad():
+            cost_actions = self._target_safety_critic.forward(states, actions)
+            cost_states = torch.zeros_like(cost_actions)
+            for _ in range(self._expectation_estimation_samples):
+                actions_policy_old = self._sample_policy_old(states, shielded=self._shielded_action_sampling)
+                cost_states += self._target_safety_critic.forward(states, actions_policy_old)
+            cost_states /= self._expectation_estimation_samples
+            cost_advantage = cost_actions - cost_states
+
+        # chi prime term (eq. 44)
+        chi_prime_term = self._lambda * (self._chi - self._avg_train_unsafe)
+
+        # overall objective
+        objective = (weighting * (reward_advantage - lambda_prime*cost_advantage)).mean() + chi_prime_term
+
+        # reset gradients
         self._policy.zero_grad()
         self._lambda.grad = None
+
+        # calculate gradients
         objective.backward()
-        update_lambda = self._lambda_lr * self._lambda.grad
-        with torch.no_grad():
-            self._lambda.copy_(self._lambda.data + update_lambda) # descent
+
+        # apply policy gradient
+        apply_gradient(self._policy, lr=1) # gradient ascent
 
-        self._writer.add_scalar(f"debug/objective", round(objective.item(),4), total_episodes)
-        self._writer.add_scalar(f"debug/lambda", round(self._lambda.item(),4), total_episodes)
+
+        # obtain action distributions
+        action_dist_old = self._policy_old.distribution(states)
+        action_dist_new = self._policy.distribution(states)
+        kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
 
-        dist_old = self._policy_old.distribution(states)
-        gradient = self.collect_gradient(self._policy)
-        self.apply_gradient(self._policy, gradient) # ascent
-        dist_new = self._policy.distribution(states)
-        self.apply_gradient(self._policy, -gradient) # descent back
-        kl_div = kl_divergence(dist_old, dist_new).mean()
+
+        # revert update
+        apply_gradient(self._policy, lr=-1) # gradient descent back
+
+        # calculate beta
         beta_term = torch.sqrt(self._delta / kl_div)
         beta_j = self._beta
+
+        # line search
         for j in range(self._line_search_iterations):
             beta_j = beta_j * (1 - beta_j)**j
             beta = beta_j * beta_term
-            update_policy = beta * gradient
-            self.apply_gradient(self._policy, update_policy) # ascent
-            dist_new = self._policy.distribution(states)
-            kl_div = kl_divergence(dist_old, dist_new).mean()
+            apply_gradient(self._policy, lr=beta) # gradient ascent
+            action_dist_new = self._policy.distribution(states)
+            kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
 
-            if kl_div <= self._delta:
-                return j, update_lambda.item()
+            if kl_div > self._delta:
+                apply_gradient(self._policy, lr=-beta) # gradient descent back
             else:
-                self.apply_gradient(self._policy, -update_policy) # descent back
-
-        return torch.nan, update_lambda.item()
+                break
+
+        # update lambda
+        with torch.no_grad():
+            self._lambda -= self._lambda_lr * self._lambda.grad # gradient descent
+            self._lambda.clamp_min_(min=0)
+
+        return j
 
-    def update(self, buffer, avg_failures, total_episodes):
-        self._avg_failures = avg_failures
-        self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
-        states, actions, rewards, costs, next_states = buffer.sample(self._batch_size)
+    def update(self):
+        """
+        Performs one iteration of updates. Updates the value network, policy network, lambda parameter and safety critic.
+        """
+        self._avg_train_unsafe = self._stats.avg_train_unsafe
+
+        states, actions, rewards, costs, next_states, dones = self._buffer.sample(self._batch_size)
         states = torch.tensor(states, device=self._device)
         actions = torch.tensor(actions, device=self._device)
         rewards = torch.tensor(rewards, device=self._device)
         costs = torch.tensor(costs, device=self._device)
         next_states = torch.tensor(next_states, device=self._device)
+        dones = torch.tensor(dones, device=self._device)
 
-        vloss = self._update_value_network(states, rewards, next_states)
-        sloss = self._update_safety_critic(states, actions, costs, next_states)
-        piter, lgradient = self._primal_dual_gradient(states, actions, rewards, next_states, total_episodes)
+        vloss = self._update_value_network(states, rewards, next_states, dones)
+        soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        piter = self._primal_dual_gradient(states, actions, rewards, next_states, dones)
+        sloss = self._update_safety_critic(states, actions, costs, next_states, dones)
+        soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
 
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
-        self.soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        self._stats.writer.add_scalar("debug/vloss", vloss, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/sloss", sloss, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/piter", piter, self._stats.total_train_episodes)
+        self._stats.writer.add_scalar("debug/lambda", self._lambda, self._stats.total_train_episodes)
 
-        self._writer.add_scalar(f"agent/policy_iterations", piter, total_episodes)
-        self._writer.add_scalar(f"agent/value_loss", round(vloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/safety_loss", round(sloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/lambda_gradient", round(lgradient,4), total_episodes)
+
     def after_updates(self):
-        self.soft_update(self._policy_old, self._policy, tau=1)
+        self._unsafety_threshold = (1 - self._gamma) * (self._chi - self._avg_train_unsafe)
+        hard_update(self._policy_old, self._policy)
 
-    @torch.no_grad
-    def sample(self, state, shielded=True):
-        state = torch.tensor(state, device=self._device)
-        if shielded:
-            action = self._policy.sample(state, num_samples=self._shield_iterations)
-
-            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1))
-            unsafety = self._safety_critic.forward(state, action).squeeze(2).permute(1,0)
-
-            mask = unsafety <= self._unsafety_threshold
-            argmax = mask.int().argmax(dim=1) # col idx of first "potentially" safe action
-            is_zero = mask[torch.arange(0, state.shape[1]), argmax] == 0 # check if in a row every col is unsafe
-            action = action.permute(1, 0, 2)
-            result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device)
-            result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...] # minimum in case only unsafe actions
-            result[~is_zero] = action[~is_zero, argmax[~is_zero], ...] # else first safe action
+    def _sample_policy_old(self, state, shielded):
+        """
+        Allows action sampling from policy_old.
+        """
+        tmp = self._policy
+        self._policy = self._policy_old
+        actions = self.sample(state, shielded=shielded)
+        self._policy = tmp
+        return actions
 
-            return result.cpu().numpy()
-
-        else:
-            action = self._policy.sample(state)
-            return action.squeeze(0).cpu().numpy()
-
-    @torch.no_grad
-    def _action_for_update(self, policy, state, shielded=True, return_dist=False):
+    def sample(self, state, shielded=True):
+        """
+        Samples and returns one action for every state. If shielded is true, performs rejection sampling according to the CSC paper using the safety critic.
+        Instead of using a loop, we sample all actions simultaneously and pick:
+        - the first with an estimated unsafety level <= self._unsafety_threshold (epsilon in the CSC paper)
+        - or else, the one that achieves maximum safety, i.e., lowest estimated unsafety
+        """
+        state = torch.tensor(state, device=self._device, dtype=torch.float64)
         if shielded:
-            dist = policy.distribution(state)
-            action = dist.sample((self._shield_iterations,))
+            # sample all actions, expand/copy state, estimate safety for all actions at the same time
+            actions = self._policy.sample(state, num_samples=self._shield_iterations) # shape: (shield_iterations, batch_size, action_size)
+            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1)) # shape: (shield_iterations, batch_size, state_size)
+            unsafety = self._safety_critic.forward(state, actions) # shape: (shield_iterations, batch_size, 1)
+            unsafety = unsafety.squeeze(2) # shape: (shield_iterations, batch_size)
 
-            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1))
-            unsafety = self._safety_critic.forward(state, action).squeeze(2).permute(1,0)
+            # check for actions that qualify (unsafety <= threshold), locate first (if exists) for every state
+            mask = unsafety <= self._unsafety_threshold # shape: (shield_iterations, batch_size)
+            row_idx = mask.int().argmax(dim=0) # idx of first "potentially" safe actions
+            batch_idx = torch.arange(0, mask.shape[1])
 
-            mask = unsafety <= self._unsafety_threshold
-            argmax = mask.int().argmax(dim=1) # col idx of first "potentially" safe action
-            is_zero = mask[torch.arange(0, state.shape[1]), argmax] == 0 # check if in a row every col is unsafe
+            # retrieve estimated safety of said action and check if it is indeed <= threshold
+            # if for a state all actions are unsafe (> threshold), argmax will return the first unsafe as mask[state] == 0 everywhere
+            unsafety_chosen = unsafety[row_idx, batch_idx]
+            all_unsafe = unsafety_chosen > self._unsafety_threshold
 
-            action = action.permute(1, 0, 2)
-            result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device)
-            result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...] # minimum in case only unsafe actions
-            result[~is_zero] = action[~is_zero, argmax[~is_zero], ...] # else first safe action
+            # for such states retrieve the action with a minimum level of unsafety
+            row_idx[all_unsafe] = unsafety[:, all_unsafe].argmin(dim=0)
+
+            # sample and return actions according to idxs
+            result = actions[row_idx, batch_idx, :]
+            return result
         else:
-            dist = policy.distribution(state)
-            result = dist.sample().squeeze(0)
-
-            if return_dist:
-                return result, dist
-            return result
\ No newline at end of file
+            # simply sample for every state one action
+            actions = self._policy.sample(state) # shape: (num_samples, batch_size, action_size)
+            return actions.squeeze(0)
\ No newline at end of file
-- 
GitLab
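
Below the patch, a minimal standalone sketch of the index-selection trick used in the new CSCAgent.sample. It is not part of the commit: the tensor sizes, the random "unsafety" values, and the threshold are made up, and it assumes torch.argmax returns the index of the first maximal value, as the patch itself does.

    import torch

    shield_iterations, batch_size = 4, 3
    unsafety = torch.rand(shield_iterations, batch_size)  # stand-in for the safety critic's output
    threshold = 0.3                                       # stand-in for self._unsafety_threshold

    # per state (column), index of the first sampled action with unsafety <= threshold
    mask = unsafety <= threshold                          # (shield_iterations, batch_size)
    row_idx = mask.int().argmax(dim=0)                    # 0 if no entry in the column is True
    batch_idx = torch.arange(batch_size)

    # columns where no action passed the threshold: fall back to the least unsafe action
    all_unsafe = unsafety[row_idx, batch_idx] > threshold
    row_idx[all_unsafe] = unsafety[:, all_unsafe].argmin(dim=0)

    print(row_idx)  # chosen sample index per state
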