Skip to content
Snippets Groups Projects
Commit acf29237 authored by Philipp Sauer
Browse files

Bugfix buffer record_range

parent c26a88d8
Branches
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ class Buffer:
# It tells us the number of times record() was called.
self.buffer_counter = 0
self.buffer_filled = False
# Instead of a list of tuples, as the experience-replay concept goes,
# we use a separate np.array for each tuple element
......@@ -33,7 +34,10 @@ class Buffer:
def record(self, obs_tuple):
# Set index to zero if buffer_capacity is exceeded,
# replacing old records
index = self.buffer_counter % self.buffer_capacity
if self.buffer_counter == self.buffer_capacity:
self.buffer_filled = True
self.buffer_counter = 0
index = self.buffer_counter
self.state_buffer[index] = obs_tuple[0]
self.action_buffer[index] = obs_tuple[1]
......@@ -77,7 +81,7 @@ class Buffer:
# We compute the loss and update parameters
def learn(self):
# Get sampling range
record_range = min(self.buffer_counter, self.buffer_capacity)
record_range = self.buffer_capacity if self.buffer_filled else self.buffer_counter
# Randomly sample indices
batch_indices = np.random.choice(record_range, self.batch_size)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment