diff --git a/src/helper.py b/src/helper.py deleted file mode 100644 index 05ea49ade13354cb86cdf44801e154a1b90e54cc..0000000000000000000000000000000000000000 --- a/src/helper.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -from torch.nn import Module - -@torch.no_grad -def soft_update(target:Module, source:Module, tau:float): - """ - Performs a soft parameter update of the target network. - """ - for tparam, sparam in zip(target.parameters(), source.parameters()): - tparam.data.copy_(tau * sparam.data + (1.0 - tau) * tparam.data) - -@torch.no_grad -def hard_update(target:Module, source:Module): - """ - Performs a hard parameter update of the target network. Copies the source parameter's data into target. - """ - soft_update(target, source, tau=1.0) - -@torch.no_grad -def apply_gradient(network:Module, lr): - """ - Uses model gradients and a learning rate to update all model parameters. - """ - for name, param in network.named_parameters(): - if not param.requires_grad: continue - if param.grad is None: continue - param += lr * param.grad \ No newline at end of file diff --git a/src/sac/__init__.py b/src/sac/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/sac/model.py b/src/sac/model.py deleted file mode 100644 index b4ab642703ff415af4be8f4fce3aee071346756f..0000000000000000000000000000000000000000 --- a/src/sac/model.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.distributions import Normal - -LOG_SIG_MAX = 2 -LOG_SIG_MIN = -20 -epsilon = 1e-6 - -# Initialize Policy weights -def weights_init_(m): - if isinstance(m, nn.Linear): - torch.nn.init.xavier_uniform_(m.weight, gain=1) - torch.nn.init.constant_(m.bias, 0) - - -class ValueNetwork(nn.Module): - def __init__(self, num_inputs, hidden_dim): - super(ValueNetwork, self).__init__() - - self.linear1 = nn.Linear(num_inputs, hidden_dim) - self.linear2 = nn.Linear(hidden_dim, hidden_dim) - self.linear3 = nn.Linear(hidden_dim, 1) - - self.apply(weights_init_) - - def forward(self, state): - x = F.relu(self.linear1(state)) - x = F.relu(self.linear2(x)) - x = self.linear3(x) - return x - - -class QNetwork(nn.Module): - def __init__(self, num_inputs, num_actions, hidden_dim): - super(QNetwork, self).__init__() - - # Q1 architecture - self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) - self.linear2 = nn.Linear(hidden_dim, hidden_dim) - self.linear3 = nn.Linear(hidden_dim, 1) - - # Q2 architecture - self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim) - self.linear5 = nn.Linear(hidden_dim, hidden_dim) - self.linear6 = nn.Linear(hidden_dim, 1) - - self.apply(weights_init_) - - def forward(self, state, action): - xu = torch.cat([state, action], 1) - - x1 = F.relu(self.linear1(xu)) - x1 = F.relu(self.linear2(x1)) - x1 = self.linear3(x1) - - x2 = F.relu(self.linear4(xu)) - x2 = F.relu(self.linear5(x2)) - x2 = self.linear6(x2) - - return x1, x2 - - -class GaussianPolicy(nn.Module): - def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): - super(GaussianPolicy, self).__init__() - - self.linear1 = nn.Linear(num_inputs, hidden_dim) - self.linear2 = nn.Linear(hidden_dim, hidden_dim) - - self.mean_linear = nn.Linear(hidden_dim, num_actions) - self.log_std_linear = nn.Linear(hidden_dim, num_actions) - - self.apply(weights_init_) - - # action rescaling - if action_space is None: - self.action_scale = torch.tensor(1.) - self.action_bias = torch.tensor(0.) - else: - self.action_scale = torch.FloatTensor( - (action_space.high - action_space.low) / 2.) - self.action_bias = torch.FloatTensor( - (action_space.high + action_space.low) / 2.) - - def forward(self, state): - x = F.relu(self.linear1(state)) - x = F.relu(self.linear2(x)) - mean = self.mean_linear(x) - log_std = self.log_std_linear(x) - log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) - return mean, log_std - - def sample(self, state): - mean, log_std = self.forward(state) - std = log_std.exp() - normal = Normal(mean, std) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) - y_t = torch.tanh(x_t) - action = y_t * self.action_scale + self.action_bias - log_prob = normal.log_prob(x_t) - # Enforcing Action Bound - log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) - log_prob = log_prob.sum(1, keepdim=True) - mean = torch.tanh(mean) * self.action_scale + self.action_bias - return action, log_prob, mean - - def to(self, device): - self.action_scale = self.action_scale.to(device) - self.action_bias = self.action_bias.to(device) - return super(GaussianPolicy, self).to(device) - - -class DeterministicPolicy(nn.Module): - def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): - super(DeterministicPolicy, self).__init__() - self.linear1 = nn.Linear(num_inputs, hidden_dim) - self.linear2 = nn.Linear(hidden_dim, hidden_dim) - - self.mean = nn.Linear(hidden_dim, num_actions) - self.noise = torch.Tensor(num_actions) - - self.apply(weights_init_) - - # action rescaling - if action_space is None: - self.action_scale = 1. - self.action_bias = 0. - else: - self.action_scale = torch.FloatTensor( - (action_space.high - action_space.low) / 2.) - self.action_bias = torch.FloatTensor( - (action_space.high + action_space.low) / 2.) - - def forward(self, state): - x = F.relu(self.linear1(state)) - x = F.relu(self.linear2(x)) - mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias - return mean - - def sample(self, state): - mean = self.forward(state) - noise = self.noise.normal_(0., std=0.1) - noise = noise.clamp(-0.25, 0.25) - action = mean + noise - return action, torch.tensor(0.), mean - - def to(self, device): - self.action_scale = self.action_scale.to(device) - self.action_bias = self.action_bias.to(device) - self.noise = self.noise.to(device) - return super(DeterministicPolicy, self).to(device) diff --git a/src/sac/sac.py b/src/sac/sac.py deleted file mode 100644 index 4f774747754483ace5c02ba2ec0f9ca3fd6cc954..0000000000000000000000000000000000000000 --- a/src/sac/sac.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import torch -import torch.nn.functional as F -from torch.optim import Adam -from .utils import soft_update, hard_update -from .model import GaussianPolicy, QNetwork, DeterministicPolicy - - -class SAC(object): - def __init__(self, num_inputs, action_space, args): - - self.gamma = args.gamma - self.tau = args.tau - self.alpha = args.alpha - - self.policy_type = args.policy - self.target_update_interval = args.target_update_interval - self.automatic_entropy_tuning = args.automatic_entropy_tuning - - self.device = torch.device("cuda" if args.cuda else "cpu") - - self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device) - self.critic_optim = Adam(self.critic.parameters(), lr=args.lr) - - self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device) - hard_update(self.critic_target, self.critic) - - if self.policy_type == "Gaussian": - # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper - if self.automatic_entropy_tuning is True: - self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() - self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) - self.alpha_optim = Adam([self.log_alpha], lr=args.lr) - - self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) - self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) - - else: - self.alpha = 0 - self.automatic_entropy_tuning = False - self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) - self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) - - def select_action(self, state, evaluate=False): - state = torch.FloatTensor(state).to(self.device).unsqueeze(0) - if evaluate is False: - action, _, _ = self.policy.sample(state) - else: - _, _, action = self.policy.sample(state) - return action.detach().cpu().numpy()[0] - - def update_parameters(self, memory, batch_size, updates): - # Sample a batch from memory - state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size) - - state_batch = torch.FloatTensor(state_batch).to(self.device) - next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) - action_batch = torch.FloatTensor(action_batch).to(self.device) - reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1) - mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1) - - with torch.no_grad(): - next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch) - qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action) - min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi - next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target) - qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step - qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] - qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] - qf_loss = qf1_loss + qf2_loss - - self.critic_optim.zero_grad() - qf_loss.backward() - self.critic_optim.step() - - pi, log_pi, _ = self.policy.sample(state_batch) - - qf1_pi, qf2_pi = self.critic(state_batch, pi) - min_qf_pi = torch.min(qf1_pi, qf2_pi) - - policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))] - - self.policy_optim.zero_grad() - policy_loss.backward() - self.policy_optim.step() - - if self.automatic_entropy_tuning: - alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() - - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - - self.alpha = self.log_alpha.exp() - alpha_tlogs = self.alpha.clone() # For TensorboardX logs - else: - alpha_loss = torch.tensor(0.).to(self.device) - alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs - - - if updates % self.target_update_interval == 0: - soft_update(self.critic_target, self.critic, self.tau) - - return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item() - - # Save model parameters - def save_checkpoint(self, env_name, suffix="", ckpt_path=None): - if not os.path.exists('checkpoints/'): - os.makedirs('checkpoints/') - if ckpt_path is None: - ckpt_path = "checkpoints/sac_checkpoint_{}_{}".format(env_name, suffix) - print('Saving models to {}'.format(ckpt_path)) - torch.save({'policy_state_dict': self.policy.state_dict(), - 'critic_state_dict': self.critic.state_dict(), - 'critic_target_state_dict': self.critic_target.state_dict(), - 'critic_optimizer_state_dict': self.critic_optim.state_dict(), - 'policy_optimizer_state_dict': self.policy_optim.state_dict()}, ckpt_path) - - # Load model parameters - def load_checkpoint(self, ckpt_path, evaluate=False): - print('Loading models from {}'.format(ckpt_path)) - if ckpt_path is not None: - checkpoint = torch.load(ckpt_path) - self.policy.load_state_dict(checkpoint['policy_state_dict']) - self.critic.load_state_dict(checkpoint['critic_state_dict']) - self.critic_target.load_state_dict(checkpoint['critic_target_state_dict']) - self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict']) - self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict']) - - if evaluate: - self.policy.eval() - self.critic.eval() - self.critic_target.eval() - else: - self.policy.train() - self.critic.train() - self.critic_target.train() - diff --git a/src/sac/utils.py b/src/sac/utils.py deleted file mode 100644 index 038ceb449cfa0ad15eca621d2ca4ca980780a133..0000000000000000000000000000000000000000 --- a/src/sac/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import math -import torch - -def create_log_gaussian(mean, log_std, t): - quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2)) - l = mean.shape - log_z = log_std - z = l[-1] * math.log(2 * math.pi) - log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z - return log_p - -def logsumexp(inputs, dim=None, keepdim=False): - if dim is None: - inputs = inputs.view(-1) - dim = 0 - s, _ = torch.max(inputs, dim=dim, keepdim=True) - outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() - if not keepdim: - outputs = outputs.squeeze(dim) - return outputs - -def soft_update(target, source, tau): - for target_param, param in zip(target.parameters(), source.parameters()): - target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) - -def hard_update(target, source): - for target_param, param in zip(target.parameters(), source.parameters()): - target_param.data.copy_(param.data)