diff --git a/src/buffer.py b/src/buffer.py
index 19df97593fc3ddd2b042b34db4cb815b82e99a70..2452601d282de134723ce6383259400c811b1db3 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -2,7 +2,7 @@ import numpy as np
 
 class ReplayBuffer():
     def __init__(self, env, cap):
-        self._cap = cap
+        self._cap = max(1,cap)
         self._size = 0
         self._ptr = 0
 
@@ -12,7 +12,6 @@ class ReplayBuffer():
         self._costs = np.zeros((cap, ), dtype=np.float64)
         self._next_states = np.zeros_like(self._states)
 
-
     def _add(self, state, action, reward, cost, next_state, start, end):
         self._states[start:end] = state
         self._actions[start:end] = action
@@ -22,7 +21,6 @@ class ReplayBuffer():
         self._ptr = end
         self._size = min(self._size + (end - start), self._cap)
 
-
    def add(self, state, action, reward, cost, next_state):
         b = state.shape[0]
         if self._ptr + b <= self._cap:
@@ -32,7 +30,6 @@ class ReplayBuffer():
             self._add(state[:d], action[:d], reward[:d], cost[:d], next_state[:d], start=self._ptr, end=self._cap)
             self._add(state[d:], action[d:], reward[d:], cost[d:], next_state[d:], start=0, end=b-d)
 
-
     def sample(self, n):
         idxs = np.random.randint(low=0, high=self._size, size=n)
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
@@ -40,8 +37,4 @@ class ReplayBuffer():
 
     def clear(self):
         self._size = 0
-        self._ptr = 0
-
-    @property
-    def size(self):
-        return self._size
\ No newline at end of file
+        self._ptr = 0
\ No newline at end of file
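
Review note (not part of the patch): a minimal standalone sketch of the ring-buffer
wraparound that the patched add() performs. It deliberately does not import
src/buffer.py, since the array-allocation lines (and how `env` determines their
shapes) fall outside these hunks; the capacity, pointer state, and the single
rewards array below are illustrative assumptions only.

    import numpy as np

    cap = 5                    # buffer capacity (the patch clamps it via max(1,cap))
    rewards = np.zeros(cap)    # stand-in for one of the per-field storage arrays
    ptr, size = 3, 3           # pretend three transitions are already stored

    batch = np.array([10.0, 20.0, 30.0])  # incoming batch, b = 3
    b = batch.shape[0]

    if ptr + b <= cap:
        # Contiguous case: one slice write, mirroring _add(start=ptr, end=ptr+b).
        rewards[ptr:ptr + b] = batch
        ptr = ptr + b          # may land exactly on cap; the next call then
                               # takes the wraparound branch below with d == 0
    else:
        # Wraparound case: fill the tail, then continue from the front,
        # mirroring the two _add() calls in the patch.
        d = cap - ptr          # slots left before the end of the buffer
        rewards[ptr:cap] = batch[:d]
        rewards[0:b - d] = batch[d:]
        ptr = b - d
    size = min(size + b, cap)

    print(rewards)  # [30.  0.  0. 10. 20.] -- overflow wrapped to the front

Splitting the batch into two contiguous slice assignments keeps every write
vectorized; per-element modulo indexing would achieve the same wraparound but
one element at a time.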