diff --git a/src/buffer.py b/src/buffer.py
index 2452601d282de134723ce6383259400c811b1db3..6b72adbb5bc8264b4db5d9144b4f9bfabb3e686d 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -1,40 +1,68 @@
 import numpy as np
 
 class ReplayBuffer():
+    """
+    Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. Handles batched experiences.
+    """
     def __init__(self, env, cap):
         self._cap = max(1,cap)
-        self._size = 0
-        self._ptr = 0
+        self._size = 0  # number of experiences in the buffer
+        self._ptr = 0  # pointer to the next available slot in the buffer
 
         self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64)
         self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64)
         self._rewards = np.zeros((cap, ), dtype=np.float64)
         self._costs = np.zeros((cap, ), dtype=np.float64)
         self._next_states = np.zeros_like(self._states)
-
-    def _add(self, state, action, reward, cost, next_state, start, end):
+        self._dones = np.zeros((cap, ), dtype=np.uint8)
+
+
+    def _add(self, state, action, reward, cost, next_state, done, start, end):
         self._states[start:end] = state
         self._actions[start:end] = action
         self._rewards[start:end] = reward
         self._costs[start:end] = cost
         self._next_states[start:end] = next_state
-        self._ptr = end
-        self._size = min(self._size + (end - start), self._cap)
+        self._dones[start:end] = done
 
-    def add(self, state, action, reward, cost, next_state):
-        b = state.shape[0]
-        if self._ptr + b <= self._cap:
-            self._add(state, action, reward, cost, next_state, start=self._ptr, end=self._ptr+b)
+
+    def add(self, state, action, reward, cost, next_state, done):
+        """
+        Adds experiences to the buffer. Assumes batched experiences.
+        """
+        n = state.shape[0]  # NOTE: n should be less than or equal to the buffer capacity
+        idx_start = self._ptr
+        idx_end = self._ptr + n
+
+        # if the buffer has capacity, add the experiences to the end of the buffer
+        if idx_end <= self._cap:
+            self._add(state, action, reward, cost, next_state, done, idx_start, idx_end)
+
+        # if the buffer does not have capacity, add the experiences to the end of the buffer and wrap around
         else:
-            d = self._cap - self._ptr
-            self._add(state[:d], action[:d], reward[:d], cost[:d], next_state[:d], start=self._ptr, end=self._cap)
-            self._add(state[d:], action[d:], reward[d:], cost[d:], next_state[d:], start=0, end=b-d)
+            k = self._cap - idx_start
+            idx_end = n - k
+            self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap)
+            self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end)
+
+        # update the buffer size and pointer
+        self._ptr = idx_end
+        if self._size < self._cap:
+            self._size = min(self._cap, self._size + n)
+
 
     def sample(self, n):
+        """
+        Samples n experiences from the buffer.
+        """
         idxs = np.random.randint(low=0, high=self._size, size=n)
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
-            self._costs[idxs], self._next_states[idxs]
+            self._costs[idxs], self._next_states[idxs], self._dones[idxs]
+
 
     def clear(self):
+        """
+        Clears the buffer.
+        """
         self._size = 0
         self._ptr = 0
\ No newline at end of file
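
A minimal usage sketch of the updated `ReplayBuffer` (not part of the diff): the module path `buffer`, the `SimpleNamespace` stand-in env, and the capacity/batch sizes below are assumptions chosen to exercise the new `done` argument and the wrap-around branch of `add`.

```python
from types import SimpleNamespace
import numpy as np

from buffer import ReplayBuffer  # assumed import path for src/buffer.py

# Stand-in for a Gym-style env: the buffer only reads observation_space/action_space shapes.
env = SimpleNamespace(
    observation_space=SimpleNamespace(shape=(3,)),
    action_space=SimpleNamespace(shape=(1,)),
)

buf = ReplayBuffer(env, cap=8)

# Add two batches of 5 transitions; the second batch triggers the wrap-around path.
for _ in range(2):
    batch = 5
    states = np.random.randn(batch, 3)
    actions = np.random.randn(batch, 1)
    rewards = np.random.randn(batch)
    costs = np.random.randn(batch)
    next_states = np.random.randn(batch, 3)
    dones = np.zeros(batch, dtype=np.uint8)
    buf.add(states, actions, rewards, costs, next_states, dones)

# sample now returns six arrays, since done flags are stored alongside the transitions.
s, a, r, c, s2, d = buf.sample(4)
print(s.shape, d.shape)  # (4, 3) (4,)
```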