diff --git a/main.py b/main.py
index 83eee6094bafa62276d76d7cf4945bd4c05fb4ae..d01e3fc16c6ac2bc06e836f4058b6ab070b10660 100644
--- a/main.py
+++ b/main.py
@@ -120,7 +120,7 @@ def setup(args):
     return env, agent, buffer, writer

 ##################
-# USEFUL FUNCTIONS
+# EXPLORATION
 ##################

 @torch.no_grad
@@ -155,7 +155,9 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
         episode_reward[mask] += reward[mask]
         episode_cost[mask] += cost[mask]

-        if done.any() or ~mask.any():
+        is_mask_zero = ~mask.any()
+
+        if done.any() or is_mask_zero:
             avg_steps += episode_steps.sum()
             avg_reward += episode_reward.sum()
             avg_cost += episode_cost.sum()
@@ -165,13 +167,14 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
             total_steps += episode_steps.sum()
             total_failures += (episode_cost >= args.cost_limit).sum()

-            if train:
+            if train and not is_mask_zero:
                 buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask])

             mask = np.ones(args.num_vectorized_envs, dtype='bool')
             episode_steps = np.zeros(args.num_vectorized_envs)
             episode_reward = np.zeros(args.num_vectorized_envs)
             episode_cost = np.zeros(args.num_vectorized_envs)
+            state, _ = env.reset() # auto resets, but is_mask_zero requires us to reset

             finished_count = episode_count
             open_episodes = num_episodes - episode_count
@@ -179,13 +182,12 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen
                 mask[idx:] = False
             episode_count += idx

-        elif train:
-            buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
-            if args.enforce_cost_limit: # we dont care about the cost limit while testing
-                mask = mask and episode_cost < args.cost_limit
         else:
-            pass # eval
-        state = next_state
+            if train:
+                buffer.add(state[mask], action[mask], reward[mask], cost[mask], next_state[mask])
+                if args.enforce_cost_limit: # we dont care about the cost limit while testing
+                    mask = mask & (episode_cost < args.cost_limit)
+            state = next_state

     avg_steps /= num_episodes
     avg_reward /= num_episodes
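
Note (reviewer sketch, not part of the patch): the switch from "and" to "&" in the mask update of the last hunk is the key correctness fix there. Python's "and" forces truth evaluation of the whole NumPy array, which raises a ValueError for arrays with more than one element, while "&" performs an elementwise AND. A minimal standalone illustration, with made-up values standing in for mask, episode_cost and args.cost_limit:

    import numpy as np

    mask = np.array([True, True, False, True])      # envs still being tracked this episode
    episode_cost = np.array([3.0, 12.0, 0.0, 5.0])  # hypothetical accumulated per-env costs
    cost_limit = 10.0                                # stand-in for args.cost_limit

    # Elementwise AND: keep only envs that are still active and under the cost limit.
    mask = mask & (episode_cost < cost_limit)
    print(mask)  # [ True False False  True]

    # The old line `mask = mask and episode_cost < cost_limit` would instead raise:
    # ValueError: The truth value of an array with more than one element is ambiguous.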