diff --git a/src/policy.py b/src/policy.py
index 80e07a6371c437e1ac11418a879608273f38f3e4..005d25fbd58577dee2d85f513b17ee63c90f4c61 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -1,8 +1,5 @@
 import torch
-import numpy as np
-import time
 from torch.distributions import kl_divergence
-from torch.func import functional_call, vmap, grad
 from .networks import GaussianPolicy, QNetwork, ValueNetwork
 
 class CSCAgent():
@@ -20,7 +17,6 @@ class CSCAgent():
         self._chi = args.csc_chi
         self._avg_failures = self._chi
         self._batch_size = args.batch_size
-        self._lambda = args.csc_lambda
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
         self._hidden_dim = args.hidden_dim
@@ -32,6 +28,7 @@ class CSCAgent():
         self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
         self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
+        self._lambda = torch.tensor(args.csc_lambda, requires_grad=True, device=self._device)
 
         self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
         self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device)
@@ -53,8 +50,29 @@ class CSCAgent():
     def soft_update(target, source, tau):
         for tparam, sparam in zip(target.parameters(), source.parameters()):
             tparam.copy_((1 - tau) * tparam.data + tau * sparam.data)
-
+
+    @staticmethod
+    @torch.no_grad
+    def apply_gradient(network:torch.nn.Module, gradient):
+        l = 0
+        for name, param in network.named_parameters():
+            if not param.requires_grad: continue
+            n = param.numel()
+            param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
+            l += n
+
+    @staticmethod
+    @torch.no_grad
+    def collect_gradient(network:torch.nn.Module):
+        grads = []
+        for name, param in network.named_parameters():
+            if not param.requires_grad: continue
+            g = torch.zeros_like(param)
+            if not param.grad is None: g += param.grad
+            grads.append(g.flatten())
+        return torch.hstack(grads)
+
 
     def _cost_advantage(self, states, actions):
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
@@ -70,91 +88,7 @@ class CSCAgent():
         value_next_states = self._target_value_network.forward(next_states)
         value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
         return value_target - value_states
-
-    def _update_policy(self, states, actions, rewards, next_states):
-        @torch.no_grad
-        def estimate_ahat(states, actions, rewards, next_states):
-            lambda_prime = self._lambda / (1 - self._gamma)
-            adv_cost = self._cost_advantage(states, actions)
-            reward_diff = self._reward_diff(states, rewards, next_states)
-            return reward_diff - lambda_prime * adv_cost
-
-        @torch.no_grad
-        def apply_gradient(gradient):
-            l = 0
-            for name, param in self._policy.named_parameters():
-                if not param.requires_grad: continue
-                n = param.numel()
-                param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
-                l += n
-
-        def log_prob_grad_batched(states, actions):
-            # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
-            params = {k: v.detach() for k, v in self._policy_old.named_parameters()}
-            buffers = {k: v.detach() for k, v in self._policy_old.named_buffers()}
-            def compute_log_prob(params, buffers, s, a):
-                s = s.unsqueeze(0)
-                a = a.unsqueeze(0)
-                lp = functional_call(self._policy_old, (params, buffers), (s,a))
-                return lp.sum()
-
-            ft_compute_grad = grad(compute_log_prob)
-            ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
-            ft_per_sample_grads = ft_compute_sample_grad(params, buffers, states, actions)
-
-            grads = []
-            for key, param in self._policy_old.named_parameters():
-                if not param.requires_grad: continue
-                g = ft_per_sample_grads[key].flatten(start_dim=1)
-                grads.append(g)
-            return torch.hstack(grads)
-
-        @torch.no_grad
-        def fisher_matrix_batched(glog_prob):
-            b, n = glog_prob.shape
-            s,i = 32,0 # batch size, index
-            fisher = torch.zeros((n,n), device=self._device)
-            while i < b:
-                g = glog_prob[i:i+s, ...]
-                fisher += (vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b
-                i += s
-            return fisher
-
-        glog_prob = log_prob_grad_batched(states, actions).detach()
-        with torch.no_grad():
-            fisher = fisher_matrix_batched(glog_prob)
-            ahat = estimate_ahat(states, actions, rewards, next_states)
-            gJ = (glog_prob * ahat).mean(dim=0).view((-1, 1))
-            del glog_prob, ahat
-
-            beta_term = torch.sqrt((2 * self._delta) / (gJ.T @ fisher @ gJ))
-
-            # https://pytorch.org/docs/stable/generated/torch.linalg.pinv.html
-            # NOTE: requires full rank on cuda
-            #gradient_term = torch.linalg.lstsq(fisher.cpu(), gJ.cpu()).solution.to(self._device) # too slow
-
-            # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32
-            time_start = time.time()
-            gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0], device=self._device)
-            time_end = time.time()
-            print(f"[DEBUG] Inverse took: {round(time_end - time_start, 2)}s")
-            del fisher, gJ
-
-        dist_old = self._policy_old.distribution(states)
-        beta_j = self._beta
-        for j in range(self._line_search_iterations):
-            beta_j = beta_j * (1 - beta_j)**j
-            gradient = beta_j * beta_term * gradient_term
-
-            apply_gradient(gradient)
-            dist_new = self._policy.distribution(states)
-            kl_div = kl_divergence(dist_old, dist_new).mean()
-            if kl_div <= self._delta:
-                return j
-            else:
-                # NOTE: probably better to save the weights and restore them
-                apply_gradient(-gradient)
-        return torch.nan
+
 
     def _update_value_network(self, states, rewards, next_states):
         value_diff = self._reward_diff(states, rewards, next_states)
@@ -187,14 +121,53 @@ class CSCAgent():
         self._optim_safety_critic.step()
         return loss.item()
 
-    @torch.no_grad
-    def _update_lambda(self, states, actions):
-        gamma_inv = 1 / (1 - self._gamma)
-        adv = self._cost_advantage(states, actions).mean()
-        chi_prime = self._chi - self._avg_failures
-        gradient = (gamma_inv * adv - chi_prime).item()
-        self._lambda -= self._lambda_lr * gradient
-        return gradient
+    def _primal_dual_gradient(self, states, actions, rewards, next_states, total_episodes):
+        # Equation 45
+        reward_advantage = self._reward_diff(states, rewards, next_states).detach()
+        lambda_prime = self._lambda / (1 - self._gamma)
+        chi_prime_term = self._lambda * (self._chi - self._avg_failures)
+        cost_advantage = self._cost_advantage(states, actions).detach()
+        a,b = self._policy.log_prob(states, actions), self._policy_old.log_prob(states, actions).detach()
+        ratio = (a - b).exp()
+
+        objective = (ratio * (reward_advantage - lambda_prime*cost_advantage)).mean() + chi_prime_term
+
+        self._policy.zero_grad()
+        self._lambda.grad = None
+        objective.backward()
+
+        update_lambda = self._lambda_lr * self._lambda.grad
+        with torch.no_grad():
+            self._lambda.copy_(self._lambda.data + update_lambda) # descent
+
+        self._writer.add_scalar(f"debug/objective", round(objective.item(),4), total_episodes)
+        self._writer.add_scalar(f"debug/lambda", round(self._lambda.item(),4), total_episodes)
+
+        dist_old = self._policy_old.distribution(states)
+        gradient = self.collect_gradient(self._policy)
+        self.apply_gradient(self._policy, gradient) # ascent
+        dist_new = self._policy.distribution(states)
+        self.apply_gradient(self._policy, -gradient) # descent back
+        kl_div = kl_divergence(dist_old, dist_new).mean()
+
+        beta_term = torch.sqrt(self._delta / kl_div)
+        beta_j = self._beta
+        for j in range(self._line_search_iterations):
+            beta_j = beta_j * (1 - beta_j)**j
+            beta = beta_j * beta_term
+
+            update_policy = beta * gradient
+            self.apply_gradient(self._policy, update_policy) # ascent
+            dist_new = self._policy.distribution(states)
+            kl_div = kl_divergence(dist_old, dist_new).mean()
+
+            if kl_div <= self._delta:
+                return j, update_lambda.item()
+            else:
+                self.apply_gradient(self._policy, -update_policy) # descent back
+
+        return torch.nan, update_lambda.item()
+
 
     def update(self, buffer, avg_failures, total_episodes):
         self._avg_failures = avg_failures
@@ -207,10 +180,9 @@ class CSCAgent():
         costs = torch.tensor(costs, device=self._device)
         next_states = torch.tensor(next_states, device=self._device)
 
-        piter = self._update_policy(states, actions, rewards, next_states)
        vloss = self._update_value_network(states, rewards, next_states)
         sloss = self._update_safety_critic(states, actions, costs, next_states)
-        lgradient = self._update_lambda(states, actions)
+        piter, lgradient = self._primal_dual_gradient(states, actions, rewards, next_states, total_episodes)
 
         self.soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
         self.soft_update(self._target_value_network, self._value_network, tau=self._tau)
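A quick way to sanity-check the new flat-gradient helpers (collect_gradient / apply_gradient) is to round-trip a gradient through the flat representation on a small stand-in module. The sketch below is illustrative only: it mirrors the two helpers against a plain torch.nn.Sequential instead of the repo's GaussianPolicy, and the shapes and tolerance are arbitrary.

import torch

# Stand-in network; the repo's GaussianPolicy is not needed to exercise the
# flatten / unflatten logic, only an nn.Module with parameters.
net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 2))

@torch.no_grad()
def collect_gradient(network):
    # Same idea as CSCAgent.collect_gradient: concatenate per-parameter grads
    # (zeros where .grad is None) into one flat vector.
    grads = []
    for param in network.parameters():
        if not param.requires_grad: continue
        g = torch.zeros_like(param)
        if param.grad is not None: g += param.grad
        grads.append(g.flatten())
    return torch.hstack(grads)

@torch.no_grad()
def apply_gradient(network, gradient):
    # Same idea as CSCAgent.apply_gradient: walk the parameters in the same
    # order and add the matching slice of the flat vector to each one.
    l = 0
    for param in network.parameters():
        if not param.requires_grad: continue
        n = param.numel()
        param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
        l += n

# Populate .grad, flatten it, take a step and revert it (as the line search
# does when a candidate step violates the KL constraint).
net(torch.randn(16, 4)).pow(2).mean().backward()
flat = collect_gradient(net)
assert flat.numel() == sum(p.numel() for p in net.parameters())
before = [p.detach().clone() for p in net.parameters()]
apply_gradient(net, flat)   # ascent
apply_gradient(net, -flat)  # revert
assert all(torch.allclose(p, b, atol=1e-5) for p, b in zip(net.parameters(), before))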
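The backtracking line search in _primal_dual_gradient scales the flat ascent direction by beta_term = sqrt(delta / KL) measured at a full step, then shrinks beta_j until the updated policy stays inside the KL trust region. The snippet below replays that acceptance rule on a toy diagonal Gaussian; delta, beta, the gradient vector and the iteration count are made-up values, and Normal stands in for GaussianPolicy.distribution.

import torch
from torch.distributions import Normal, kl_divergence

delta, beta, line_search_iterations = 0.01, 0.5, 10  # illustrative hyperparameters

def distribution(params):
    # params = [mean_1, mean_2, log_std_1, log_std_2]; stands in for the policy head.
    return Normal(params[:2], params[2:].exp())

params_old = torch.zeros(4)
dist_old = distribution(params_old)

# Pretend this flat vector came from objective.backward() + collect_gradient().
gradient = torch.tensor([0.5, -0.3, 0.2, 0.1])

# Scale so that a full step would sit roughly on the KL trust-region boundary.
kl_full = kl_divergence(dist_old, distribution(params_old + gradient)).mean()
beta_term = torch.sqrt(delta / kl_full)

beta_j = beta
for j in range(line_search_iterations):
    beta_j = beta_j * (1 - beta_j)**j             # same shrinking schedule as the diff
    step = beta_j * beta_term * gradient
    kl = kl_divergence(dist_old, distribution(params_old + step)).mean()
    if kl <= delta:                               # accept the first step inside the region
        print(f"accepted at j={j}, kl={kl.item():.4f}")
        break
    # otherwise the agent reverts the step and shrinks beta_j further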