diff --git a/src/policy.py b/src/policy.py
index 2ef02cd9f7532430f364c470fbb3c6f66394a780..66cd7112c3ea86b38368faa86e54537dceac8a46 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -23,16 +23,18 @@ class CSCAgent():
         self._lambda = args.csc_lambda
         self._expectation_estimation_samples = args.expectation_estimation_samples
         self._tau = args.tau
+        self._hidden_dim = args.hidden_dim
 
         num_inputs = env.observation_space.shape[-1]
         num_actions = env.action_space.shape[-1]
-        hidden_dim = args.hidden_dim
-        self._policy = GaussianPolicy(num_inputs, num_actions, hidden_dim, env.action_space).to(self._device)
-        self._safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device)
-        self._value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device)
-
-        self._target_safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device)
-        self._target_value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device)
+        self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
+        self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim).to(self._device)
+        self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
+
+        self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device)
+        self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim).to(self._device)
+        self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device)
+        self.soft_update(self._policy_old, self._policy, tau=1)
 
         self.soft_update(self._target_safety_critic, self._safety_critic, tau=1)
         self.soft_update(self._target_value_network, self._value_network, tau=1)
@@ -41,7 +43,9 @@ class CSCAgent():
 
         self._policy.eval()
         self._safety_critic.eval()
-
+        self._value_network.eval()
+
+
     @staticmethod
     @torch.no_grad
     def soft_update(target, source, tau):
@@ -53,30 +57,18 @@ class CSCAgent():
         cost_actions = self._safety_critic.forward(states, actions)
         cost_states = torch.zeros_like(cost_actions)
         for _ in range(self._expectation_estimation_samples):
-            a = self._policy.sample(states).squeeze(0).detach()
-            cost_states += self._target_safety_critic.forward(states, a)
+            a = self._policy_old.sample(states).squeeze(0).detach()
+            cost_states += self._safety_critic.forward(states, a)
         cost_states /= self._expectation_estimation_samples
         return cost_actions - cost_states
-
 
     def _reward_diff(self, states, rewards, next_states):
         value_states = self._value_network.forward(states)
         with torch.no_grad():
             value_next_states = self._target_value_network.forward(next_states)
         value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
         return value_target - value_states
-
-
-    def _update_value_network(self, states, rewards, next_states):
-        value_diff = self._reward_diff(states, rewards, next_states)
-        loss = 1/2 * torch.square(value_diff).mean()
-
-        self._optim_value_network.zero_grad()
-        loss.backward()
-        self._optim_value_network.step()
-        return loss.item()
-
-
+
     def _update_policy(self, states, actions, rewards, next_states):
         @torch.no_grad
         def estimate_ahat(states, actions, rewards, next_states):
@@ -94,14 +86,14 @@ class CSCAgent():
                 param.copy_(param.data + gradient[l:l+n].reshape(param.shape))
                 l += n
 
-        def log_prob_grad(states, actions):
+        def log_prob_grad_batched(states, actions):
             # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
-            params = {k: v.detach() for k, v in self._policy.named_parameters()}
-            buffers = {k: v.detach() for k, v in self._policy.named_buffers()}
+            params = {k: v.detach() for k, v in self._policy_old.named_parameters()}
+            buffers = {k: v.detach() for k, v in self._policy_old.named_buffers()}
             def compute_log_prob(params, buffers, s, a):
                 s = s.unsqueeze(0)
                 a = a.unsqueeze(0)
-                lp = functional_call(self._policy, (params, buffers), (s,a))
+                lp = functional_call(self._policy_old, (params, buffers), (s,a))
                 return lp.sum()
 
             ft_compute_grad = grad(compute_log_prob)
@@ -109,14 +101,14 @@ class CSCAgent():
             ft_per_sample_grads = ft_compute_sample_grad(params, buffers, states, actions)
 
             grads = []
-            for key, param in self._policy.named_parameters():
+            for key, param in self._policy_old.named_parameters():
                 if not param.requires_grad: continue
                 g = ft_per_sample_grads[key].flatten(start_dim=1)
                 grads.append(g)
             return torch.hstack(grads)
-
+
         @torch.no_grad
-        def fisher_matrix(glog_prob):
+        def fisher_matrix_batched(glog_prob):
             b, n = glog_prob.shape
             s,i = 32,0 # batch size, index
             fisher = torch.zeros((n,n), device=self._device)
@@ -126,11 +118,12 @@ class CSCAgent():
                 i += s
             return fisher
 
 
-        glog_prob = log_prob_grad(states, actions).detach()
+        glog_prob = log_prob_grad_batched(states, actions).detach()
         with torch.no_grad():
-            fisher = fisher_matrix(glog_prob)
+            fisher = fisher_matrix_batched(glog_prob)
             ahat = estimate_ahat(states, actions, rewards, next_states)
             gJ = (glog_prob * ahat).mean(dim=0).view((-1, 1))
+            del glog_prob, ahat
 
             beta_term = torch.sqrt((2 * self._delta) / (gJ.T @ fisher @ gJ))
@@ -138,15 +131,14 @@ class CSCAgent():
 
             # NOTE: requires full rank on cuda
             #gradient_term = torch.linalg.lstsq(fisher.cpu(), gJ.cpu()).solution.to(self._device) # too slow
-            # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32, --hidden_dim 64 numpy is slow
+            # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32
            time_start = time.time()
             gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0], device=self._device)
             time_end = time.time()
             print(f"[DEBUG] Inverse took: {round(time_end - time_start, 2)}s")
+            del fisher, gJ
 
-        del glog_prob, ahat, fisher, gJ
-
-        dist_old = self._policy.distribution(states)
+        dist_old = self._policy_old.distribution(states)
         beta_j = self._beta
         for j in range(self._line_search_iterations):
             beta_j = beta_j * (1 - beta_j)**j
@@ -162,6 +154,14 @@ class CSCAgent():
         apply_gradient(-gradient)
         return torch.nan
 
+    def _update_value_network(self, states, rewards, next_states):
+        value_diff = self._reward_diff(states, rewards, next_states)
+        loss = 1/2 * torch.square(value_diff).mean()
+
+        self._optim_value_network.zero_grad()
+        loss.backward()
+        self._optim_value_network.step()
+        return loss.item()
 
     def _update_safety_critic(self, states, actions, costs, next_states):
         safety_sa_env = self._safety_critic.forward(states, actions)
@@ -185,7 +185,6 @@ class CSCAgent():
         self._optim_safety_critic.step()
         return loss.item()
 
-
     @torch.no_grad
     def _update_lambda(self, states, actions):
         gamma_inv = 1 / (1 - self._gamma)
@@ -195,7 +194,6 @@ class CSCAgent():
         self._lambda -= gradient
         return gradient
 
-
     def update(self, buffer, avg_failures, total_episodes):
         self._avg_failures = avg_failures
         self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
@@ -220,6 +218,8 @@ class CSCAgent():
         self._writer.add_scalar(f"agent/safety_loss", round(sloss,4), total_episodes)
         self._writer.add_scalar(f"agent/lambda_gradient", round(lgradient,4), total_episodes)
 
+    def after_updates(self):
+        self.soft_update(self._policy_old, self._policy, tau=1)
 
     @torch.no_grad
     def sample(self, state, shielded=True):
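
Reference sketch (not part of the patch): a minimal, self-contained illustration of the torch.func per-sample-gradient recipe that log_prob_grad_batched follows (per the PyTorch tutorial linked in the diff), the outer-product empirical Fisher that fisher_matrix_batched accumulates in chunks, and the least-squares solve used for the natural-gradient direction. GaussianHead, per_sample_log_prob_grads, the unit-variance log-probability, the toy dimensions, and the random placeholder advantages are illustrative assumptions, not the repository's GaussianPolicy or its interface.

# Sketch only: per-sample log-prob gradients, empirical Fisher, natural-gradient direction.
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap


class GaussianHead(nn.Module):
    """Toy stand-in for the policy: maps a state to the mean of a unit-variance Gaussian."""

    def __init__(self, obs_dim, act_dim, hidden_dim=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.Tanh(),
                                 nn.Linear(hidden_dim, act_dim))

    def forward(self, s, a):
        # log N(a | mu(s), I), dropping the additive constant
        mu = self.net(s)
        return (-0.5 * (a - mu).pow(2)).sum(dim=-1)


def per_sample_log_prob_grads(policy, states, actions):
    """Gradients of log pi(a|s) w.r.t. the policy parameters, one row per sample: (batch, n_params)."""
    params = {k: v.detach() for k, v in policy.named_parameters()}
    buffers = {k: v.detach() for k, v in policy.named_buffers()}

    def compute_log_prob(params, buffers, s, a):
        lp = functional_call(policy, (params, buffers), (s.unsqueeze(0), a.unsqueeze(0)))
        return lp.sum()

    # grad() differentiates w.r.t. the first argument (params); vmap maps over the batch dimension.
    per_sample = vmap(grad(compute_log_prob), in_dims=(None, None, 0, 0))(params, buffers, states, actions)
    return torch.hstack([g.flatten(start_dim=1) for g in per_sample.values()])


if __name__ == "__main__":
    obs_dim, act_dim, batch = 4, 2, 256
    policy = GaussianHead(obs_dim, act_dim)
    states, actions = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)

    glog_prob = per_sample_log_prob_grads(policy, states, actions)        # (batch, n)
    fisher = (glog_prob.unsqueeze(2) @ glog_prob.unsqueeze(1)).mean(0)    # empirical Fisher, (n, n)

    ahat = torch.randn(batch, 1)                                          # placeholder advantage estimates
    gJ = (glog_prob * ahat).mean(dim=0).view(-1, 1)                       # policy-gradient estimate, (n, 1)

    # Natural-gradient direction: solve fisher @ x = gJ. On CPU, torch.linalg.lstsq handles a
    # possibly rank-deficient Fisher; the patch falls back to np.linalg.lstsq for speed.
    direction = torch.linalg.lstsq(fisher, gJ).solution
    print(glog_prob.shape, fisher.shape, direction.shape)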