Skip to content
Snippets Groups Projects
Unverified Commit efc304ed authored by Break Yang's avatar Break Yang Committed by GitHub
Browse files

Use torch.div for floor div in replay buffer (#1396)

* Use torch.div for floor div in replay buffer

* Address comments
parent 3b64a133
No related branches found
No related tags found
No related merge requests found
......@@ -226,9 +226,9 @@ class ReplayBuffer(RingBuffer):
def _index_to_env_id_idx(self, indices):
"""Convert indices used by SegmentTree to (env_id, idx)."""
# need to use `//` here. Newer versions of pytorch will do automatic
# type promtion and will generate float indices if `/` is used.
env_ids = indices // self._max_length
# Here ``torch.div`` with ``rounding_mode="floor"`` will produce
# result with integer dtype.
env_ids = torch.div(indices, self._max_length, rounding_mode="floor")
return env_ids, indices % self._max_length
def _change_mini_batch_length(self, mini_batch_length):
......@@ -530,10 +530,12 @@ class ReplayBuffer(RingBuffer):
n = (current_pos - idx - 1) / L
"""
# need to use `//` here. Newer versions of pytorch will do automatic
# type promtion and will generate float indices if `/` is used.
return ((self._current_pos[env_ids] - x - 1) //
self._max_length) * self._max_length + x
# Here ``torch.div`` with ``rounding_mode="floor"`` will produce
# result with integer dtype.
return torch.div(
self._current_pos[env_ids] - x - 1,
self._max_length,
rounding_mode="floor") * self._max_length + x
def _set_default_return(self, env_ids):
ind = (env_ids, self.circular(self._current_pos[env_ids] - 1))
......
......@@ -68,10 +68,7 @@ class SegmentTree(nn.Module):
"""
Calculate the parent value from its children.
"""
# need to use `//` here. Newer versions of pytorch will do automatic
# type promtion and will generate float indices if `/` is used.
indices = indices // 2
indices = torch.unique(indices)
indices = torch.unique(indices >> 1)
left = self._values[indices * 2]
right = self._values[indices * 2 + 1]
self._values[indices] = op(left, right)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment