diff --git a/src/networks.py b/src/networks.py index d0fed82aa4687a8110a3d4bee90379a42c779ac8..0ba8102654787ff21902f205cb1a2bad7f4e1620 100644 --- a/src/networks.py +++ b/src/networks.py @@ -9,7 +9,6 @@ Copied from spice project. LOG_SIG_MAX = 2 LOG_SIG_MIN = -20 -epsilon = 1e-6 # Initialize Policy weights def weights_init_(m): @@ -47,14 +46,12 @@ class QNetwork(nn.Module): self.apply(weights_init_) def forward(self, state, action): - xu = torch.cat([state, action], -1) - - x1 = F.relu(self.linear1(xu)) - x1 = F.relu(self.linear2(x1)) - x1 = self.linear3(x1) - x1 = self.last_activation(x1) - - return x1 + x = torch.cat([state, action], -1) + x = F.relu(self.linear1(x)) + x = F.relu(self.linear2(x)) + x = self.linear3(x) + x = self.last_activation(x) + return x class GaussianPolicy(nn.Module): @@ -74,15 +71,11 @@ class GaussianPolicy(nn.Module): self.action_scale = torch.tensor(1.) self.action_bias = torch.tensor(0.) else: - # NOTE: we assume all envs in the vector have the same action sapce self.action_scale = torch.FloatTensor( (action_space.high[0] - action_space.low[0]) / 2.) self.action_bias = torch.FloatTensor( (action_space.high[0] + action_space.low[0]) / 2.) - def __call__(self, state, action): - return self.log_prob(state, action) - def forward(self, state): x = F.relu(self.linear1(state)) x = F.relu(self.linear2(x)) @@ -94,7 +87,7 @@ class GaussianPolicy(nn.Module): def distribution(self, state): mean, log_std = self.forward(state) std = log_std.exp() - normal = Normal(mean, std, validate_args=False) + normal = Normal(mean, std) return normal def log_prob(self, state, action): @@ -102,9 +95,7 @@ class GaussianPolicy(nn.Module): return dist.log_prob(action) def sample(self, state, num_samples=1): - mean, log_std = self.forward(state) - std = log_std.exp() - normal = Normal(mean, std) + normal = self.distribution(state) x_t = normal.rsample((num_samples, )) # for reparameterization trick (mean + std * N(0,1)) y_t = torch.tanh(x_t) action = y_t * self.action_scale + self.action_bias diff --git a/src/stats.py b/src/stats.py index cd59501739397d968c86a7565e2d73baa0e2fbc0..bd886d969da5dda35968e6c15a7baa7d4e6885b9 100644 --- a/src/stats.py +++ b/src/stats.py @@ -6,42 +6,35 @@ class Statistics(): self.total_failures = 0 self.total_steps = 0 self.total_updates = 0 - - self.time_avg_train = 0 - self.train_count = 0 - self._time_train = None - - self.time_avg_update = 0 - self.update_count = 0 - self._time_update = None - self.time_avg_test = 0 - self.test_count = 0 - self._time_test = None + for name in ['train', 'update', 'test']: + setattr(self, f"time_avg_{name}", 0) + setattr(self, f"{name}_count", 0) + setattr(self, f"_time_{name}", None) - self.time_start = time.time() + self._time_start = time.time() def total_time(self, name=None): if name is None: - t = time.time() - self.time_start + t = time.time() - self._time_start else: - t = self.__dict__[f"time_avg_{name}"] * self.__dict__[f"{name}_count"] + t = getattr(self, f"time_avg_{name}") * getattr(self, f"{name}_count") return round(t,2) def begin(self, name): - self.__dict__[f"_time_{name}"] = time.time() + setattr(self, f"_time_{name}", time.time()) def end(self, name): t = time.time() - d = t - self.__dict__[f"_time_{name}"] - self.__dict__[f"{name}_count"] += 1 - c = self.__dict__[f"{name}_count"] + d = t - getattr(self, f"_time_{name}") + setattr(self, f"{name}_count", getattr(self, f"{name}_count") + 1) + c = getattr(self, f"{name}_count") a = 1/c b = 1 - a - self.__dict__[f"time_avg_{name}"] = a*d + b*self.__dict__[f"time_avg_{name}"] + setattr(self, f"time_avg_{name}", a*d + b*getattr(self, f"time_avg_{name}")) def print(self):