From e07e8788d54cf5ec6eda46a617c095bf0df7d915 Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Thu, 6 Mar 2025 15:43:47 +0100
Subject: [PATCH] Add type hints, drop dead code, run both test modes

Annotate the agent and buffer parameters of run_vectorized_exploration
and give it a short docstring, remove the unused single_exploration
helper, and re-enable the test loop so evaluation runs both shielded
and unshielded.
---
 main.py | 56 ++++----------------------------------------------------
 1 file changed, 4 insertions(+), 52 deletions(-)

diff --git a/main.py b/main.py
index 87d53e6..bf05dfb 100644
--- a/main.py
+++ b/main.py
@@ -130,7 +130,8 @@ def setup(args):
 ##################
 
 @torch.no_grad
-def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train=True, shielded=True):
+def run_vectorized_exploration(args, env, agent: CSCCQLSAC, buffer: ReplayBuffer, stats: Statistics, train=True, shielded=True):
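+    """Roll out episodes in the vectorized env; optionally train the agent and apply the shield."""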
     # track currently running and leftover episodes
     open_episodes = args.train_episodes if train else args.test_episodes
     running_episodes = args.num_vectorized_envs
@@ -236,55 +237,6 @@ def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train
     if not train:
         stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps)
 
-def single_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
-    # NOTE: Unused function
-    state, info = env.reset(seed=random.randint(0, 2**31-1))
-    episode_reward = 0
-    episode_cost = 0
-    episode_steps = 0
-    done = False
-
-    while not done:
-        with torch.no_grad():
-            action = agent.get_action(np.expand_dims(state, 0), eval=not train).cpu().numpy().squeeze(0)
-        next_state, reward, cost, terminated, truncated, info = env.step(action)
-        done = terminated or truncated
-        episode_reward += reward
-        episode_cost += cost
-        episode_steps += 1
-
-        state = np.expand_dims(state, 0)
-        next_state = np.expand_dims(next_state, 0)
-        action = np.expand_dims(action, 0)
-        reward = np.array([reward])
-        cost = np.array([cost])
-
-        if train:
-            buffer.add(state, action, reward, cost, next_state, np.array([terminated]))
-            if stats.total_train_steps >= 10_000:
-                x = agent.learn(experiences=buffer.sample(n=args.batch_size))
-                actor_loss, alpha_loss, critic1_loss, critic2_loss, cql1_scaled_loss, cql2_scaled_loss, current_alpha, cql_alpha_loss, cql_alpha = x
-                stats.writer.add_scalar("debug/actor_loss", actor_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/alpha_loss", alpha_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/critic1_loss", critic1_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/critic2_loss", critic2_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/cql1_scaled_loss", cql1_scaled_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/cql2_scaled_loss", cql2_scaled_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/current_alpha", current_alpha, stats.total_updates)
-                stats.writer.add_scalar("debug/cql_alpha_loss", cql_alpha_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/cql_alpha", cql_alpha, stats.total_updates)
-                stats.total_updates += 1
-
-        state = next_state.squeeze(0)
-        
-    stats.writer.add_scalar("train/returns", episode_reward, stats.total_train_steps)
-    stats.writer.add_scalar("train/costs", episode_cost, stats.total_train_steps)
-    stats.writer.add_scalar("train/steps", episode_steps, stats.total_train_steps)
-
-    stats.total_train_episodes += 1
-    stats.total_train_steps += episode_steps
-    stats.total_train_unsafe += int(episode_cost > args.cost_limit)
-
 ##################
 # MAIN LOOP
 ##################
@@ -312,8 +264,8 @@ def main(args, env, agent, buffer, stats):
                 buffer.clear()
 
         # Test loop (shielded and unshielded)
-        # for shielded in [True, False]:
-        run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=True)
+        for shielded in [True, False]:
+            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
 
 if __name__ == '__main__':
     args = cmd_args()
-- 
GitLab