From 5608549803fa81526f4ff1603b290035a3468d3e Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Mon, 30 Sep 2024 17:00:29 +0200 Subject: [PATCH] Bugfix eval total counter --- main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 259f40f..ed05316 100644 --- a/main.py +++ b/main.py @@ -156,12 +156,13 @@ def run_vectorized_exploration(args, env:safety_gymnasium.vector.VectorEnv, agen avg_cost += episode_cost.sum() avg_failures += (episode_cost >= args.cost_limit).sum() - stats.total_episodes += episode_count - finished_count - stats.total_steps += episode_steps.sum() - stats.total_failures += (episode_cost >= args.cost_limit).sum() + if train: + stats.total_episodes += episode_count - finished_count + stats.total_steps += episode_steps.sum() + stats.total_failures += (episode_cost >= args.cost_limit).sum() - if train and not is_mask_zero: - buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask]) + if not is_mask_zero: + buffer.add(state[mask], action[mask], reward[mask], cost[mask], np.stack(info['final_observation'], axis=0)[mask]) mask = np.ones(args.num_vectorized_envs, dtype='bool') episode_steps = np.zeros(args.num_vectorized_envs) -- GitLab