diff --git a/main.py b/main.py
index a32e903157287163fb07253f67282d58e4646247..8215e83da8493c2b95087fe51331d78bf901b928 100644
--- a/main.py
+++ b/main.py
@@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from src.stats import Statistics
 from src.buffer import ReplayBuffer
-from src.cql_sac.agent import CQLSAC
+from src.cql_sac.agent import CSCCQLSAC
 from src.environment import create_environment
 
 ##################
@@ -64,7 +64,7 @@ def cmd_args():
 
     # cql args
     cql_args = parser.add_argument_group('CQL')
-    cql_args.add_argument("--cql_with_lagrange", action="store", type=float, default=0, metavar="N",
+    cql_args.add_argument("--cql_with_lagrange", action="store_true", default=False,
                           help="")
    cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
                           help="")
@@ -121,7 +121,7 @@ def setup(args):
     env = create_environment(args=args)
     buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
     stats = Statistics(writer=writer)
-    agent = CQLSAC(env=env, args=args, stats=stats)
+    agent = CSCCQLSAC(env=env, args=args, stats=stats)
 
     return env, agent, buffer, stats
@@ -195,6 +195,8 @@ def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train
 
             stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks)
             stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks)
+            stats.log_train_history((episode_cost[done_masked] > args.cost_limit).astype(np.uint8))
+
             stats.total_train_episodes += done_masked_count
             stats.total_train_steps += episode_steps[done_masked].sum()
             stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
@@ -305,8 +307,8 @@ def main(args, env, agent, buffer, stats):
         buffer.clear()
 
         # Test loop (shielded and unshielded)
-        for shielded in [True, False]:
-            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
+        # for shielded in [True, False]:
+        run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=True)
 
 if __name__ == '__main__':
     args = cmd_args()
diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py
index 85abbdca76bdf71552c718fe603f60bd645c5662..3e6e4cb0cef61d21173db14de1c7c44b4b89f447 100644
--- a/src/cql_sac/agent.py
+++ b/src/cql_sac/agent.py
@@ -9,7 +9,7 @@ import math
 import copy
 
 
-class CQLSAC(nn.Module):
+class CSCCQLSAC(nn.Module):
     """Interacts with and learns from the environment."""
 
     def __init__(self,
@@ -24,7 +24,9 @@ class CQLSAC(nn.Module):
             env : the vector environment
             args : the argparse arguments
         """
-        super(CQLSAC, self).__init__()
+        super(CSCCQLSAC, self).__init__()
+        self.stats = stats
+
         state_size = env.observation_space.shape[-1]
         action_size = env.action_space.shape[-1]
         hidden_size = args.hidden_size
@@ -39,7 +41,7 @@ class CQLSAC(nn.Module):
 
         self.target_entropy = -action_size # -dim(A)
 
-        self.log_alpha = torch.tensor([0.0], requires_grad=True)
+        self.log_alpha = torch.tensor([0.0], requires_grad=True, device=self.device)
         self.alpha = self.log_alpha.exp().detach()
         self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=self.learning_rate)
 
@@ -48,19 +50,41 @@ class CQLSAC(nn.Module):
         self.cql_temp = args.cql_temp
         self.cql_weight = args.cql_weight
         self.cql_target_action_gap = args.cql_target_action_gap
-        self.cql_log_alpha = torch.zeros(1, requires_grad=True)
+        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
         self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate)
+
+        # CSC params
+        self.csc_shield_iterations = 100
+        self.csc_alpha = args.csc_alpha
+        self.csc_beta = args.csc_beta
+        self.csc_delta = args.csc_delta
+        self.csc_chi = args.csc_chi
+        self.csc_avg_unsafe = args.csc_chi
+
+        self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device)
+        self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate)
 
         # Actor Network
         self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device)
-        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.learning_rate)
+        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.learning_rate)
+
+        # Safety Critic Network (w/ Target Network)
+        self.safety_critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic2 = Critic(state_size, action_size, hidden_size).to(self.device)
+
+        self.safety_critic1_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic1_target.load_state_dict(self.safety_critic1.state_dict())
+
+        self.safety_critic2_target = Critic(state_size, action_size, hidden_size).to(self.device)
+        self.safety_critic2_target.load_state_dict(self.safety_critic2.state_dict())
+
+        self.safety_critic1_optimizer = optim.Adam(self.safety_critic1.parameters(), lr=self.learning_rate)
+        self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate)
 
         # Critic Network (w/ Target Network)
         self.critic1 = Critic(state_size, action_size, hidden_size).to(self.device)
         self.critic2 = Critic(state_size, action_size, hidden_size).to(self.device)
 
-        assert self.critic1.parameters() != self.critic2.parameters()
-
         self.critic1_target = Critic(state_size, action_size, hidden_size).to(self.device)
         self.critic1_target.load_state_dict(self.critic1.state_dict())
 
@@ -72,32 +96,54 @@ class CQLSAC(nn.Module):
 
     def get_action(self, state, eval=False):
-        """Returns actions for given state as per current policy."""
-        state = torch.from_numpy(state).float().to(self.device)
+        """
+        Returns shielded actions for given state as per current policy.
 
-        with torch.no_grad():
-            if eval:
-                action = self.actor_local.get_det_action(state)
-            else:
-                action = self.actor_local.get_action(state)
-        return action
+        Note: eval is currently ignored.
+        """
+        state = torch.from_numpy(state).float().to(self.device)
+
+        batch_size = state.shape[0]
+        unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe)
+        unsafety_best = torch.full((batch_size, ), fill_value=unsafety_threshold+1).to(self.device)
+        action_best = torch.zeros(batch_size, self.action_size).to(self.device)
+
+        # Run at max 'csc_shield_iterations' iterations to find safe action
+        for _ in range(self.csc_shield_iterations):
+            # If all actions are already safe, break
+            mask_safe = unsafety_best <= unsafety_threshold
+            if mask_safe.all(): break
+
+            # Sample new actions
+            with torch.no_grad():
+                action = self.actor_local.get_action(state).to(self.device)
+
+            # Estimate safety of new actions
+            q1 = self.safety_critic1(state, action)
+            q2 = self.safety_critic2(state, action)
+            unsafety = torch.min(q1, q2).squeeze(1)
+
+            # Update best actions if they are still unsafe and new actions are safer
+            mask_update = (~mask_safe) & (unsafety < unsafety_best)
+            unsafety_best[mask_update] = unsafety[mask_update]
+            action_best[mask_update] = action[mask_update]
+
+        return action_best
 
     def calc_policy_loss(self, states, alpha):
         actions_pred, log_pis = self.actor_local.evaluate(states)
         q1 = self.critic1(states, actions_pred.squeeze(0))
         q2 = self.critic2(states, actions_pred.squeeze(0))
-        min_Q = torch.min(q1,q2).cpu()
-        actor_loss = ((alpha * log_pis.cpu() - min_Q )).mean()
+        min_Q = torch.min(q1,q2)
+        actor_loss = ((alpha * log_pis - min_Q )).mean()
         return actor_loss, log_pis
 
     def _compute_policy_values(self, obs_pi, obs_q):
         #with torch.no_grad():
         actions_pred, log_pis = self.actor_local.evaluate(obs_pi)
-
-        qs1 = self.critic1(obs_q, actions_pred)
-        qs2 = self.critic2(obs_q, actions_pred)
-
+        qs1 = self.safety_critic1(obs_q, actions_pred)
+        qs2 = self.safety_critic2(obs_q, actions_pred)
         return qs1 - log_pis.detach(), qs2 - log_pis.detach()
 
     def _compute_random_values(self, obs, actions, critic):
@@ -118,6 +164,7 @@ class CQLSAC(nn.Module):
             experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples
             gamma (float): discount factor
         """
+        self.csc_avg_unsafe = self.stats.train_unsafe_avg
         states, actions, rewards, costs, next_states, dones = experiences
 
         states = torch.from_numpy(states).float().to(self.device)
@@ -127,49 +174,63 @@ class CQLSAC(nn.Module):
         next_states = torch.from_numpy(next_states).float().to(self.device)
         dones = torch.from_numpy(dones).float().to(self.device).view(-1, 1)
 
-        # ---------------------------- update actor ---------------------------- #
-        current_alpha = copy.deepcopy(self.alpha)
-        actor_loss, log_pis = self.calc_policy_loss(states, current_alpha)
-        self.actor_optimizer.zero_grad()
-        actor_loss.backward()
-        self.actor_optimizer.step()
-
-        # Compute alpha loss
-        alpha_loss = - (self.log_alpha.exp() * (log_pis.cpu() + self.target_entropy).detach().cpu()).mean()
-        self.alpha_optimizer.zero_grad()
-        alpha_loss.backward()
-        self.alpha_optimizer.step()
-        self.alpha = self.log_alpha.exp().detach()
-
         # ---------------------------- update critic ---------------------------- #
         # Get predicted next-state actions and Q values from target models
         with torch.no_grad():
             next_action, new_log_pi = self.actor_local.evaluate(next_states)
             Q_target1_next = self.critic1_target(next_states, next_action)
             Q_target2_next = self.critic2_target(next_states, next_action)
-            Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha.to(self.device) * new_log_pi
+            Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha * new_log_pi
             # Compute Q targets for current states (y_i)
             Q_targets = rewards + (self.gamma * (1 - dones) * Q_target_next)
 
-        # Compute critic loss
         q1 = self.critic1(states, actions)
         q2 = self.critic2(states, actions)
 
         critic1_loss = F.mse_loss(q1, Q_targets)
         critic2_loss = F.mse_loss(q2, Q_targets)
+
+        # Update critics
+        # critic 1
+        self.critic1_optimizer.zero_grad()
+        critic1_loss.backward(retain_graph=True)
+        clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param)
+        self.critic1_optimizer.step()
+        # critic 2
+        self.critic2_optimizer.zero_grad()
+        critic2_loss.backward()
+        clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param)
+        self.critic2_optimizer.step()
+
+        # ---------------------------- update safety critic ---------------------------- #
+        # Get predicted next-state actions and Q values from target models
+        with torch.no_grad():
+            next_action, new_log_pi = self.actor_local.evaluate(next_states)
+            Q_target1_next = self.safety_critic1_target(next_states, next_action)
+            Q_target2_next = self.safety_critic2_target(next_states, next_action)
+            Q_target_next = torch.min(Q_target1_next, Q_target2_next) # - self.alpha * new_log_pi
+            # Compute Q targets for current states (y_i)
+            Q_targets = costs + (self.gamma * (1 - dones) * Q_target_next)
+
+        # Compute safety_critic loss
+        q1 = self.safety_critic1(states, actions)
+        q2 = self.safety_critic2(states, actions)
+
+        safety_critic1_loss = F.mse_loss(q1, Q_targets)
+        safety_critic2_loss = F.mse_loss(q2, Q_targets)
 
         # CQL addon
-        random_actions = torch.FloatTensor(q1.shape[0] * 10, actions.shape[-1]).uniform_(-1, 1).to(self.device)
-        num_repeat = int (random_actions.shape[0] / states.shape[0])
+        num_repeat = 10
+        random_actions = torch.FloatTensor(q1.shape[0] * num_repeat, actions.shape[-1]).uniform_(-1, 1).to(self.device)
         temp_states = states.unsqueeze(1).repeat(1, num_repeat, 1).view(states.shape[0] * num_repeat, states.shape[1])
         temp_next_states = next_states.unsqueeze(1).repeat(1, num_repeat, 1).view(next_states.shape[0] * num_repeat, next_states.shape[1])
 
         current_pi_values1, current_pi_values2 = self._compute_policy_values(temp_states, temp_states)
         next_pi_values1, next_pi_values2 = self._compute_policy_values(temp_next_states, temp_states)
 
-        random_values1 = self._compute_random_values(temp_states, random_actions, self.critic1).reshape(states.shape[0], num_repeat, 1)
-        random_values2 = self._compute_random_values(temp_states, random_actions, self.critic2).reshape(states.shape[0], num_repeat, 1)
+        random_values1 = self._compute_random_values(temp_states, random_actions, self.safety_critic1).reshape(states.shape[0], num_repeat, 1)
+        random_values2 = self._compute_random_values(temp_states, random_actions, self.safety_critic2).reshape(states.shape[0], num_repeat, 1)
 
         current_pi_values1 = current_pi_values1.reshape(states.shape[0], num_repeat, 1)
         current_pi_values2 = current_pi_values2.reshape(states.shape[0], num_repeat, 1)
@@ -183,14 +244,14 @@ class CQLSAC(nn.Module):
 
         assert cat_q1.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q1 instead has shape: {cat_q1.shape}"
         assert cat_q2.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q2 instead has shape: {cat_q2.shape}"
-
-        cql1_scaled_loss = ((torch.logsumexp(cat_q1 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) - q1.mean()) * self.cql_weight
-        cql2_scaled_loss = ((torch.logsumexp(cat_q2 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) - q2.mean()) * self.cql_weight
+        # flipped sign of cql1_scaled_loss and cql2_scaled_loss
+        cql1_scaled_loss = -(torch.logsumexp(cat_q1 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q1.mean() * self.cql_weight)
+        cql2_scaled_loss = -(torch.logsumexp(cat_q2 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q2.mean() * self.cql_weight)
 
         cql_alpha_loss = torch.FloatTensor([0.0])
-        cql_alpha = torch.FloatTensor([0.0])
+        cql_alpha = torch.FloatTensor([1.0])
         if self.cql_with_lagrange:
-            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), min=0.0, max=1000000.0).to(self.device)
+            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), min=0.0, max=1000000.0)
             cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.cql_target_action_gap)
             cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.cql_target_action_gap)
 
@@ -199,27 +260,95 @@ class CQLSAC(nn.Module):
             cql_alpha_loss.backward(retain_graph=True)
             self.cql_alpha_optimizer.step()
 
-        total_c1_loss = critic1_loss + cql1_scaled_loss
-        total_c2_loss = critic2_loss + cql2_scaled_loss
+        total_c1_loss = safety_critic1_loss + cql1_scaled_loss
+        total_c2_loss = safety_critic2_loss + cql2_scaled_loss
 
-        # Update critics
-        # critic 1
-        self.critic1_optimizer.zero_grad()
+        # Update safety_critics
+        # safety_critic 1
+        self.safety_critic1_optimizer.zero_grad()
         total_c1_loss.backward(retain_graph=True)
-        clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param)
-        self.critic1_optimizer.step()
-        # critic 2
-        self.critic2_optimizer.zero_grad()
+        clip_grad_norm_(self.safety_critic1.parameters(), self.clip_grad_param)
+        self.safety_critic1_optimizer.step()
+        # safety_critic 2
+        self.safety_critic2_optimizer.zero_grad()
         total_c2_loss.backward()
-        clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param)
-        self.critic2_optimizer.step()
+        clip_grad_norm_(self.safety_critic2.parameters(), self.clip_grad_param)
+        self.safety_critic2_optimizer.step()
+
+        # ---------------------------- update csc lambda ---------------------------- #
+        # Estimate cost advantage
+        with torch.no_grad():
+            q1 = self.safety_critic1(states, actions)
+            q2 = self.safety_critic2(states, actions)
+            v = torch.min(q1, q2)
+
+            new_action, new_log_pi = self.actor_local.evaluate(states)
+            q1 = self.safety_critic1(states, new_action)
+            q2 = self.safety_critic2(states, new_action)
+            q = torch.min(q1, q2)
+
+            cost_advantage = (q - v).mean()
+
+        # Compute csc lambda loss
+        csc_lambda_loss = -self.csc_lambda*(self.csc_avg_unsafe + (1 / (1 - self.gamma)) * cost_advantage - self.csc_chi)
+
+        self.csc_lambda_optimizer.zero_grad()
+        csc_lambda_loss.backward()
+        self.csc_lambda_optimizer.step()
+
+        # ---------------------------- update actor ---------------------------- #
+        # Estimate reward advantage
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
+        v = torch.min(q1, q2).detach()
+
+        new_action, new_log_pi = self.actor_local.evaluate(states)
+        q1 = self.critic1(states, new_action)
+        q2 = self.critic2(states, new_action)
+        q = torch.min(q1, q2)
+
+        reward_advantage = q - v
+
+        # Optimize actor
+        actor_loss = ((self.alpha * new_log_pi - reward_advantage)).mean()
+        self.actor_optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor_optimizer.step()
+
+        # Compute alpha loss
+        alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean()
+        self.alpha_optimizer.zero_grad()
+        alpha_loss.backward()
+        self.alpha_optimizer.step()
+        self.alpha = self.log_alpha.exp().detach()
 
         # ----------------------- update target networks ----------------------- #
         self.soft_update(self.critic1, self.critic1_target)
         self.soft_update(self.critic2, self.critic2_target)
+        self.soft_update(self.safety_critic1, self.safety_critic1_target)
+        self.soft_update(self.safety_critic2, self.safety_critic2_target)
 
-        return actor_loss.item(), alpha_loss.item(), critic1_loss.item(), critic2_loss.item(), cql1_scaled_loss.item(), cql2_scaled_loss.item(), current_alpha, cql_alpha_loss.item(), cql_alpha.item()
+        # ----------------------- update stats ----------------------- #
+        data = {
+            "actor_loss": actor_loss.item(),
+            "alpha_loss": alpha_loss.item(),
+            "alpha": self.alpha.item(),
+            "lambda_loss": csc_lambda_loss.item(),
+            "lambda": self.csc_lambda.item(),
+            "critic1_loss": critic1_loss.item(),
+            "critic2_loss": critic2_loss.item(),
+            "cql1_scaled_loss": cql1_scaled_loss.item(),
+            "cql2_scaled_loss": cql2_scaled_loss.item(),
+            "total_c1_loss": total_c1_loss.item(),
+            "total_c2_loss": total_c2_loss.item(),
+            "cql_alpha_loss": cql_alpha_loss.item(),
+            "cql_alpha": cql_alpha.item()
+        }
+        if self.stats.total_updates % 8 == 0:
+            self.stats.log_update_tensorboard(data)
+        self.stats.total_updates += 1
+        return data
 
     def soft_update(self, local_model , target_model):
         """Soft update model parameters.
diff --git a/src/stats.py b/src/stats.py
index 8b7494dec164ecaff2614582eb49da9c343b9aef..4f263bf73e5ef7fb25774d047297b8fcbfce7bf9 100644
--- a/src/stats.py
+++ b/src/stats.py
@@ -46,6 +46,11 @@ class Statistics:
         self.test_reward_history.clear()
         self.test_cost_history.clear()
         self.test_unsafe_history.clear()
+
+    def log_update_tensorboard(self, data:dict):
+        for k, v in data.items():
+            name = f"update/{k}"
+            self.writer.add_scalar(name, v, self.total_updates)
 
     def log_tensorboard(self, name: str, values: np.ndarray, ticks: np.ndarray):
         if values.size == 1: