Commit c2058916 authored by Philipp Sauer

Temporary helper, policy, networks commit

parent 7dffa991
import torch
from torch.nn import Module


@torch.no_grad()
def soft_update(target: Module, source: Module, tau: float):
    """
    Performs a soft (Polyak) parameter update of the target network:
    target <- tau * source + (1 - tau) * target.
    """
    for tparam, sparam in zip(target.parameters(), source.parameters()):
        tparam.data.copy_(tau * sparam.data + (1.0 - tau) * tparam.data)


@torch.no_grad()
def hard_update(target: Module, source: Module):
    """
    Performs a hard parameter update of the target network: copies the
    source parameters' data into the target.
    """
    soft_update(target, source, tau=1.0)


@torch.no_grad()
def apply_gradient(network: Module, lr: float):
    """
    Applies the stored gradients to all trainable parameters with a fixed
    learning rate. Note that `param += lr * grad` is a gradient-*ascent*
    step; negate the loss (or the learning rate) for descent.
    """
    for param in network.parameters():
        if not param.requires_grad or param.grad is None:
            continue
        param += lr * param.grad
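A minimal usage sketch of the helpers above in a target-network training loop; the critic networks, loss, tau, and learning rate are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

critic = nn.Linear(4, 1)         # hypothetical online network
critic_target = nn.Linear(4, 1)  # hypothetical target network

hard_update(critic_target, critic)  # start from identical weights

loss = critic(torch.ones(4)).pow(2).sum()
loss.backward()
apply_gradient(critic, lr=-3e-4)  # negative lr here: descent on the loss

# After each optimization step on `critic`, track it slowly:
soft_update(critic_target, critic, tau=0.005)  # Polyak averaging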
@@ -3,12 +3,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
"""
Copied from spice project.
"""
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6
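# Note (assumption, not from this diff): these bounds are presumably applied
# to the log-std head inside forward (elided in this hunk), in the standard
# SAC style, e.g.
#   log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
# so that std = log_std.exp() stays positive and finite.
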
# Initialize Policy weights
def weights_init_(m):
@@ -85,14 +82,35 @@ class GaussianPolicy(nn.Module):
        return mean, log_std

    def distribution(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        try:
            normal = Normal(mean, std)
        except ValueError:
            # Dump diagnostics and abort on invalid distribution parameters.
            print("state:", state)
            print(mean)
            print(log_std)
            print(std)
            print("linear1:", self.linear1)
            print(self.linear2)
            print(self.mean_linear)
            print(self.log_std_linear)
            print("bias:", self.action_bias)
            print(self.action_scale)
            exit(0)
        return normal
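
    # Note (assumption): with distribution argument validation enabled, as it
    # is by default in recent PyTorch, Normal raises ValueError for NaN/inf
    # means or non-positive scales, so the debug dump above fires exactly when
    # the network has started emitting invalid outputs.
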
    def log_prob(self, state, action):
        dist = self.distribution(state)
        # Undo the affine rescaling and tanh squashing to recover the
        # pre-squash Gaussian sample, then apply the tanh change-of-variables
        # correction (SAC-style).
        y_t = (action - self.action_bias) / self.action_scale
        x_t = torch.atanh(y_t)
        log_prob = dist.log_prob(x_t)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(axis=-1, keepdim=True)
        return log_prob
    def sample(self, state, num_samples=1):
        normal = self.distribution(state)
......
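
As a standalone sanity check of the tanh change-of-variables correction in log_prob above (all names and values here are illustrative, not part of the commit):

import torch
from torch.distributions import Normal

scale, bias, eps = 2.0, 0.0, 1e-6             # illustrative action scale/bias
base = Normal(torch.zeros(3), torch.ones(3))  # stand-in for distribution(state)
x_t = base.sample()                           # pre-squash Gaussian sample
action = torch.tanh(x_t) * scale + bias       # squashed, rescaled action

# Recompute the log-density exactly as log_prob does:
y_t = (action - bias) / scale
log_prob = base.log_prob(torch.atanh(y_t))
log_prob -= torch.log(scale * (1 - y_t.pow(2)) + eps)
print(log_prob.sum(-1, keepdim=True))         # log pi(action) per state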