From 2013b9f9baf1bc1745a45f3d92f6b82ad1905e6b Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Sun, 22 Dec 2024 15:34:34 +0100
Subject: [PATCH] Small adjustments

---
 src/buffer.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/src/buffer.py b/src/buffer.py
index 19df975..2452601 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -2,7 +2,7 @@ import numpy as np
 
 class ReplayBuffer():
     def __init__(self, env, cap):
-        self._cap = cap
+        self._cap = max(1,cap)
         self._size = 0
         self._ptr = 0
 
@@ -12,7 +12,6 @@ class ReplayBuffer():
         self._costs = np.zeros((cap, ), dtype=np.float64)
         self._next_states = np.zeros_like(self._states)
 
-
     def _add(self, state, action, reward, cost, next_state, start, end):
         self._states[start:end] = state
         self._actions[start:end] = action
@@ -22,7 +21,6 @@ class ReplayBuffer():
         self._ptr = end
         self._size = min(self._size + (end - start), self._cap)
 
-
     def add(self, state, action, reward, cost, next_state):
         b = state.shape[0]
         if self._ptr + b <= self._cap:
@@ -32,7 +30,6 @@ class ReplayBuffer():
             self._add(state[:d], action[:d], reward[:d], cost[:d], next_state[:d], start=self._ptr, end=self._cap)
             self._add(state[d:], action[d:], reward[d:], cost[d:], next_state[d:], start=0, end=b-d)
 
-
     def sample(self, n):
         idxs = np.random.randint(low=0, high=self._size, size=n)
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
@@ -40,8 +37,4 @@ class ReplayBuffer():
 
     def clear(self):
         self._size = 0
-        self._ptr = 0
-
-    @property
-    def size(self):
-        return self._size
\ No newline at end of file
+        self._ptr = 0
\ No newline at end of file
-- 
GitLab
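
For context on the wrap-around write that ReplayBuffer.add performs when a batch crosses the end of the buffer, below is a minimal standalone sketch of that indexing logic. It is not part of the patch: the capacity, pointer values, 1-D states array, and batch contents are illustrative assumptions, while the real buffer stores full transition tuples whose array shapes are set up in __init__.

    # Minimal, self-contained sketch of the circular-buffer insert done by
    # ReplayBuffer.add. All concrete values here are illustrative assumptions.
    import numpy as np

    cap = 5                                  # capacity (the patch clamps this via max(1, cap))
    states = np.zeros((cap,), dtype=np.float64)
    ptr, size = 3, 3                         # pretend three items are already stored

    batch = np.array([10.0, 11.0, 12.0, 13.0])
    b = batch.shape[0]

    if ptr + b <= cap:
        # Fast path: the whole batch fits before the end of the buffer.
        states[ptr:ptr + b] = batch
        ptr = ptr + b                        # may land exactly on cap; the next add wraps
    else:
        # Wrap-around path: write the first d items up to the end, then the
        # remaining b - d items at the start, overwriting the oldest data.
        d = cap - ptr
        states[ptr:cap] = batch[:d]
        states[0:b - d] = batch[d:]
        ptr = b - d
    size = min(size + b, cap)

    print(states)        # [12. 13.  0. 10. 11.]
    print(ptr, size)     # 2 5

Splitting the batch at d = cap - ptr keeps both writes as contiguous slice assignments, which is why the patched _add takes explicit start/end indices and is called twice in the wrap-around case.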