diff --git a/src/policy.py b/src/policy.py index 722070dc3e6a21cfada441a2a33e1f8244bbc1f0..80e07a6371c437e1ac11418a879608273f38f3e4 100644 --- a/src/policy.py +++ b/src/policy.py @@ -265,7 +265,6 @@ class CSCAgent(): result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...] # minimum in case only unsafe actions result[~is_zero] = action[~is_zero, argmax[~is_zero], ...] # else first safe action - result = result else: dist = policy.distribution(state) result = dist.sample().squeeze(0)