diff --git a/src/buffer.py b/src/buffer.py
index 2452601d282de134723ce6383259400c811b1db3..6b72adbb5bc8264b4db5d9144b4f9bfabb3e686d 100644
--- a/src/buffer.py
+++ b/src/buffer.py
@@ -1,40 +1,68 @@
 import numpy as np
 
 class ReplayBuffer():
+    """
+    Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. Handles batched experiences.
+    """
     def __init__(self, env, cap):
         self._cap = max(1,cap)
-        self._size = 0
-        self._ptr = 0
+        self._size = 0              # number of experiences in the buffer
+        self._ptr = 0               # pointer to the next available slot in the buffer
 
         self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64)
         self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64)
         self._rewards = np.zeros((cap, ), dtype=np.float64)
         self._costs = np.zeros((cap, ), dtype=np.float64)
         self._next_states = np.zeros_like(self._states)
-    
-    def _add(self, state, action, reward, cost, next_state, start, end):
+        self._dones = np.zeros((cap, ), dtype=np.uint8)
+
+
+    def _add(self, state, action, reward, cost, next_state, done, start, end):
         self._states[start:end] = state
         self._actions[start:end] = action
         self._rewards[start:end] = reward
         self._costs[start:end] = cost
         self._next_states[start:end] = next_state
-        self._ptr = end
-        self._size = min(self._size + (end - start), self._cap)
+        self._dones[start:end] = done
 
-    def add(self, state, action, reward, cost, next_state):
-        b = state.shape[0]
-        if self._ptr + b <= self._cap:
-            self._add(state, action, reward, cost, next_state, start=self._ptr, end=self._ptr+b)
+
+    def add(self, state, action, reward, cost, next_state, done):
+        """
+        Adds experiences to the buffer. Assumes batched experiences.
+        """
+        n = state.shape[0]          # NOTE: n should be less than or equal to the buffer capacity
+        idx_start = self._ptr
+        idx_end = self._ptr + n
+
+        # if the buffer has capacity, add the experiences to the end of the buffer
+        if idx_end <= self._cap:
+            self._add(state, action, reward, cost, next_state, done, idx_start, idx_end)
+        
+        # if the buffer does not have capacity, add the experiences to the end of the buffer and wrap around
         else:
-            d = self._cap - self._ptr
-            self._add(state[:d], action[:d], reward[:d], cost[:d], next_state[:d], start=self._ptr, end=self._cap)
-            self._add(state[d:], action[d:], reward[d:], cost[d:], next_state[d:], start=0, end=b-d)
+            k = self._cap - idx_start
+            idx_end = n - k
+            self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap)
+            self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end)
+        
+        # update the buffer size and pointer 
+        self._ptr = idx_end
+        if self._size < self._cap:
+            self._size = min(self._cap, self._size + n)
+
 
     def sample(self, n):
+        """
+        Samples n experiences from the buffer.
+        """
         idxs = np.random.randint(low=0, high=self._size, size=n)
         return self._states[idxs], self._actions[idxs], self._rewards[idxs], \
-            self._costs[idxs], self._next_states[idxs]
+            self._costs[idxs], self._next_states[idxs], self._dones[idxs]
     
+
     def clear(self):
+        """
+        Clears the buffer.
+        """
         self._size = 0
         self._ptr = 0
\ No newline at end of file