Skip to content
Snippets Groups Projects
Commit 7ac2f5d3 authored by Philipp Sauer's avatar Philipp Sauer
Browse files

Bugfix enforce_cost_limit

parent 1c7e5747
No related branches found
No related tags found
No related merge requests found
......@@ -120,7 +120,7 @@ def setup(args):
return env, agent, buffer, writer
##################
# USEFUL FUNCTIONS
# EXPLORATION
##################
@torch.no_grad
......@@ -155,7 +155,9 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
episode_reward[mask] += reward[mask]
episode_cost[mask] += cost[mask]
if done.any() or ~mask.any():
is_mask_zero = ~mask.any()
if done.any() or is_mask_zero:
avg_steps += episode_steps.sum()
avg_reward += episode_reward.sum()
avg_cost += episode_cost.sum()
......@@ -165,13 +167,14 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
total_steps += episode_steps.sum()
total_failures += (episode_cost >= args.cost_limit).sum()
if train:
if train and not is_mask_zero:
buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])
mask = np.ones(args.num_vectorized_envs, dtype='bool')
episode_steps = np.zeros(args.num_vectorized_envs)
episode_reward = np.zeros(args.num_vectorized_envs)
episode_cost = np.zeros(args.num_vectorized_envs)
state, _ = env.reset() # auto resets, but is_mask_zero requires us to reset
finished_count = episode_count
open_episodes = num_episodes - episode_count
......@@ -179,13 +182,12 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
mask[idx:] = False
episode_count += idx
elif train:
buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
if args.enforce_cost_limit: # we dont care about the cost limit while testing
mask = mask and episode_cost < args.cost_limit
else:
pass # eval
state = next_state
if train:
buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
if args.enforce_cost_limit: # we dont care about the cost limit while testing
mask = mask & (episode_cost < args.cost_limit)
state = next_state
avg_steps /= num_episodes
avg_reward /= num_episodes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment