diff --git a/src/buffer.py b/src/buffer.py
index 0972609cd546cf3a87a0a58c9fbff714ec65b96e..19df97593fc3ddd2b042b34db4cb815b82e99a70 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -1,17 +1,16 @@
-import numpy
-import gymnasium
+import numpy as np
 
 
 class ReplayBuffer():
-    def __init__(self, env:gymnasium.Env, cap):
+    def __init__(self, env, cap):
         self._cap = cap
         self._size = 0
         self._ptr = 0
-        self._states = numpy.zeros((cap, env.observation_space.shape[-1]), dtype=numpy.float64)
-        self._actions = numpy.zeros((cap, env.action_space.shape[-1]), dtype=numpy.float64)
-        self._rewards = numpy.zeros((cap, ), dtype=numpy.float64)
-        self._costs = numpy.zeros((cap, ), dtype=numpy.float64)
-        self._next_states = numpy.zeros_like(self._states)
+        self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64)
+        self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64)
+        self._rewards = np.zeros((cap, ), dtype=np.float64)
+        self._costs = np.zeros((cap, ), dtype=np.float64)
+        self._next_states = np.zeros_like(self._states)
 
 
     def _add(self, state, action, reward, cost, next_state, start, end):
@@ -35,7 +34,7 @@ class ReplayBuffer():
 
 
     def sample(self, n):
-        idxs = numpy.random.randint(low=0, high=self._size, size=n)
+        idxs = np.random.randint(low=0, high=self._size, size=n)
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
             self._costs[idxs], self._next_states[idxs]
 
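
A side effect of dropping the `gymnasium` import and the `env: gymnasium.Env` annotation is that `ReplayBuffer` now only duck-types the environment: it reads `env.observation_space.shape[-1]` and `env.action_space.shape[-1]` and nothing else. A minimal smoke-test sketch of that, assuming the module is importable as `src.buffer`; the `SimpleNamespace` stand-in env and the direct write to `_size` are for illustration only, since the insertion path (`_add`) is not shown in this diff:

```python
import numpy as np
from types import SimpleNamespace

from src.buffer import ReplayBuffer  # assumed import path, per the diff header

# Any object exposing observation_space.shape and action_space.shape works
# now that the gymnasium dependency is gone.
dummy_env = SimpleNamespace(
    observation_space=SimpleNamespace(shape=(3,)),  # e.g. 3-dim observations
    action_space=SimpleNamespace(shape=(1,)),       # e.g. 1-dim actions
)

buf = ReplayBuffer(dummy_env, cap=1000)
buf._size = 100  # pretend 100 transitions are stored; normally _add sets this

# sample() draws n uniform indices from [0, _size) and returns five arrays.
states, actions, rewards, costs, next_states = buf.sample(32)
print(states.shape, actions.shape, rewards.shape)  # (32, 3) (32, 1) (32,)
```

The same sketch also runs against a real Gymnasium env (e.g. `gym.make("Pendulum-v1")`), since `Box` spaces expose the same `shape` attribute.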