diff --git a/src/networks.py b/src/networks.py index fbd1072c52531163103d82d90dc8dea8daf197e3..d0fed82aa4687a8110a3d4bee90379a42c779ac8 100644 --- a/src/networks.py +++ b/src/networks.py @@ -42,7 +42,7 @@ class QNetwork(nn.Module): self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) self.linear2 = nn.Linear(hidden_dim, hidden_dim) self.linear3 = nn.Linear(hidden_dim, 1) - self.last_activation = F.sigmoid if sigmoid_activation else nn.Identity + self.last_activation = F.sigmoid if sigmoid_activation else nn.Identity() self.apply(weights_init_)