From 1c7e57471ab7b270f6d10db46ed2a1f431f8edca Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Mon, 30 Sep 2024 00:14:45 +0200 Subject: [PATCH] Renaming --- src/buffer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/buffer.py b/src/buffer.py index 0972609..19df975 100644 --- a/src/buffer.py +++ b/src/buffer.py @@ -1,17 +1,16 @@ -import numpy -import gymnasium +import numpy as np class ReplayBuffer(): - def __init__(self, env:gymnasium.Env, cap): + def __init__(self, env, cap): self._cap = cap self._size = 0 self._ptr = 0 - self._states = numpy.zeros((cap, env.observation_space.shape[-1]), dtype=numpy.float64) - self._actions = numpy.zeros((cap, env.action_space.shape[-1]), dtype=numpy.float64) - self._rewards = numpy.zeros((cap, ), dtype=numpy.float64) - self._costs = numpy.zeros((cap, ), dtype=numpy.float64) - self._next_states = numpy.zeros_like(self._states) + self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64) + self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64) + self._rewards = np.zeros((cap, ), dtype=np.float64) + self._costs = np.zeros((cap, ), dtype=np.float64) + self._next_states = np.zeros_like(self._states) def _add(self, state, action, reward, cost, next_state, start, end): @@ -35,7 +34,7 @@ class ReplayBuffer(): def sample(self, n): - idxs = numpy.random.randint(low=0, high=self._size, size=n) + idxs = np.random.randint(low=0, high=self._size, size=n) return self._states[idxs], self._actions[idxs], self._rewards[idxs], \ self._costs[idxs], self._next_states[idxs] -- GitLab