Skip to content
Snippets Groups Projects
Commit 2013b9f9 authored by Philipp Sauer's avatar Philipp Sauer
Browse files

Small adjustments

parent bf28fdc1
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@ import numpy as np
class ReplayBuffer():
def __init__(self, env, cap):
self._cap = cap
self._cap = max(1,cap)
self._size = 0
self._ptr = 0
......@@ -12,7 +12,6 @@ class ReplayBuffer():
self._costs = np.zeros((cap, ), dtype=np.float64)
self._next_states = np.zeros_like(self._states)
def _add(self, state, action, reward, cost, next_state, start, end):
self._states[start:end] = state
self._actions[start:end] = action
......@@ -22,7 +21,6 @@ class ReplayBuffer():
self._ptr = end
self._size = min(self._size + (end - start), self._cap)
def add(self, state, action, reward, cost, next_state):
b = state.shape[0]
if self._ptr + b <= self._cap:
......@@ -32,7 +30,6 @@ class ReplayBuffer():
self._add(state[:d], action[:d], reward[:d], cost[:d], next_state[:d], start=self._ptr, end=self._cap)
self._add(state[d:], action[d:], reward[d:], cost[d:], next_state[d:], start=0, end=b-d)
def sample(self, n):
idxs = np.random.randint(low=0, high=self._size, size=n)
return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
......@@ -40,8 +37,4 @@ class ReplayBuffer():
def clear(self):
self._size = 0
self._ptr = 0
@property
def size(self):
return self._size
\ No newline at end of file
self._ptr = 0
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment