diff --git a/main.py b/main.py
index 87d53e6f93e62c4f34be8391a1b61426391307ba..bf05dfb64a1147af9a4fac7ab988b081b365c0af 100644
--- a/main.py
+++ b/main.py
@@ -130,7 +130,7 @@ def setup(args):
 
 ##################
 @torch.no_grad
-def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train=True, shielded=True):
+def run_vectorized_exploration(args, env, agent:CSCCQLSAC, buffer:ReplayBuffer, stats:Statistics, train=True, shielded=True):
     # track currently running and leftover episodes
     open_episodes = args.train_episodes if train else args.test_episodes
     running_episodes = args.num_vectorized_envs
@@ -236,55 +236,6 @@ def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train
     if not train:
         stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps)
 
-def single_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
-    # NOTE: Unused function
-    state, info = env.reset(seed=random.randint(0, 2**31-1))
-    episode_reward = 0
-    episode_cost = 0
-    episode_steps = 0
-    done = False
-
-    while not done:
-        with torch.no_grad():
-            action = agent.get_action(np.expand_dims(state, 0), eval=not train).cpu().numpy().squeeze(0)
-        next_state, reward, cost, terminated, truncated, info = env.step(action)
-        done = terminated or truncated
-        episode_reward += reward
-        episode_cost += cost
-        episode_steps += 1
-
-        state = np.expand_dims(state, 0)
-        next_state = np.expand_dims(next_state, 0)
-        action = np.expand_dims(action, 0)
-        reward = np.array([reward])
-        cost = np.array([cost])
-
-        if train:
-            buffer.add(state, action, reward, cost, next_state, np.array([terminated]))
-            if stats.total_train_steps >= 10_000:
-                x = agent.learn(experiences=buffer.sample(n=args.batch_size))
-                actor_loss, alpha_loss, critic1_loss, critic2_loss, cql1_scaled_loss, cql2_scaled_loss, current_alpha, cql_alpha_loss, cql_alpha = x
-                stats.writer.add_scalar("debug/actor_loss", actor_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/alpha_loss", alpha_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/critic1_loss", critic1_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/critic2_loss", critic2_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/cql1_scaled_loss", cql1_scaled_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/cql2_scaled_loss", cql2_scaled_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/current_alpha", current_alpha, stats.total_updates)
-                stats.writer.add_scalar("debug/cql_alpha_loss", cql_alpha_loss, stats.total_updates)
-                stats.writer.add_scalar("debug/cql_alpha", cql_alpha, stats.total_updates)
-                stats.total_updates += 1
-
-        state = next_state.squeeze(0)
-
-    stats.writer.add_scalar("train/returns", episode_reward, stats.total_train_steps)
-    stats.writer.add_scalar("train/costs", episode_cost, stats.total_train_steps)
-    stats.writer.add_scalar("train/steps", episode_steps, stats.total_train_steps)
-
-    stats.total_train_episodes += 1
-    stats.total_train_steps += episode_steps
-    stats.total_train_unsafe += int(episode_cost > args.cost_limit)
-
 ##################
 # MAIN LOOP
 ##################
@@ -312,8 +263,8 @@ def main(args, env, agent, buffer, stats):
         buffer.clear()
 
         # Test loop (shielded and unshielded)
-        # for shielded in [True, False]:
-        run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=True)
+        for shielded in [True, False]:
+            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
 
 if __name__ == '__main__':
     args = cmd_args()