Commit 88376f26 authored by Philipp Sauer

Updated buffer

parent 2013b9f9
import numpy as np


class ReplayBuffer:
    """
    Buffer for storing experiences. Supports adding batched experiences, sampling, and clearing the buffer.
    """
    def __init__(self, env, cap):
        self._cap = max(1, cap)
        self._size = 0  # number of experiences currently in the buffer
        self._ptr = 0   # pointer to the next free slot in the buffer
        # allocate with self._cap (not the raw cap) so a non-positive cap still yields valid arrays
        self._states = np.zeros((self._cap, env.observation_space.shape[-1]), dtype=np.float64)
        self._actions = np.zeros((self._cap, env.action_space.shape[-1]), dtype=np.float64)
        self._rewards = np.zeros((self._cap,), dtype=np.float64)
        self._costs = np.zeros((self._cap,), dtype=np.float64)
        self._next_states = np.zeros_like(self._states)
        self._dones = np.zeros((self._cap,), dtype=np.uint8)

    def _add(self, state, action, reward, cost, next_state, done, start, end):
        self._states[start:end] = state
        self._actions[start:end] = action
        self._rewards[start:end] = reward
        self._costs[start:end] = cost
        self._next_states[start:end] = next_state
        self._dones[start:end] = done
    def add(self, state, action, reward, cost, next_state, done):
"""
Adds experiences to the buffer. Assumes batched experiences.
"""
n = state.shape[0] # NOTE: n should be less than or equal to the buffer capacity
idx_start = self._ptr
idx_end = self._ptr + n
        # if the buffer has enough room, add the experiences at the current pointer
        if idx_end <= self._cap:
            self._add(state, action, reward, cost, next_state, done, idx_start, idx_end)
        # otherwise, fill the remaining slots and wrap the rest around to the front
        else:
            k = self._cap - idx_start  # number of experiences that fit before wrapping
            idx_end = n - k
            self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap)
            self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end)
        # update the buffer pointer and size; the size saturates at capacity
        self._ptr = idx_end
        self._size = min(self._cap, self._size + n)
    def sample(self, n):
        """
        Samples n experiences uniformly at random from the buffer. Assumes the buffer is non-empty.
        """
        idxs = np.random.randint(low=0, high=self._size, size=n)
        return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
               self._costs[idxs], self._next_states[idxs], self._dones[idxs]
    def clear(self):
        """
        Clears the buffer.
        """
        self._size = 0
        self._ptr = 0
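
For reference, a minimal usage sketch. The SimpleNamespace env here is a hypothetical stand-in invented for illustration; any Gym-style object exposing observation_space.shape and action_space.shape works the same way.

from types import SimpleNamespace

env = SimpleNamespace(
    observation_space=SimpleNamespace(shape=(3,)),
    action_space=SimpleNamespace(shape=(1,)),
)
buf = ReplayBuffer(env, cap=8)

# two batches of 5 overflow the capacity-8 buffer, exercising the wrap-around:
# after the second add, _size == 8 and _ptr == 2 (slots 0 and 1 were overwritten)
for _ in range(2):
    buf.add(
        np.random.randn(5, 3),        # states
        np.random.randn(5, 1),        # actions
        np.random.randn(5),           # rewards
        np.random.randn(5),           # costs
        np.random.randn(5, 3),        # next states
        np.zeros(5, dtype=np.uint8),  # dones
    )

states, actions, rewards, costs, next_states, dones = buf.sample(4)
print(states.shape, actions.shape, dones.shape)  # (4, 3) (4, 1) (4,)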