Commit 69b76df7 authored by Philipp Sauer

Reworked main

parent 88376f26
import argparse
import safety_gymnasium
import numpy as np
import torch
import random
import os
import json
import datetime

from torch.utils.tensorboard import SummaryWriter

from src.stats import Statistics
from src.buffer import ReplayBuffer
from src.policy import CSCAgent
##################
#   ARGPARSER
##################
def cmd_args():
    parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))

    # environment args
    env_args = parser.add_argument_group('Environment')
    env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
                          help="Set the environment")
    env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
                          help="Set a cost limit/budget")
    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
                          help="Sets the number of vectorized environments")

    # train and test args
    train_test_args = parser.add_argument_group('Train and Test')
    train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
                                 help="Number of episodes until policy optimization")
    train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
                                 help="Perform evaluation after N * train_episodes episodes of training")
    train_test_args.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
                                 help="Number of updates performed after each training step")
    train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
                                 help="Number of episodes used for testing")
    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
                                 help="Total number of steps until training is finished")
    train_test_args.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
                                 help="Batch size used for training")
    train_test_args.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
                                 help="Factor used in soft update of target network")

    # buffer args
    buffer_args = parser.add_argument_group('Buffer')
    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
                             help="Define the maximum capacity of the replay buffer")
    buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
                             help="Clear the replay buffer after every optimization step")

    # csc args
    csc_args = parser.add_argument_group('Agent')
    csc_args.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
                          help="Maximum number of actions sampled during shielding")
    csc_args.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
                          help="Maximum number of line search update iterations")
    csc_args.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
                          help="Number of samples to estimate expectations")
    csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
                          help="Set the value of chi")
    csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
                          help="Set the value of delta")
    csc_args.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N",
                          help="Set the value of gamma")
    csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
                          help="Set the value of beta")
    csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
                          help="Set the value of alpha")
    csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
                          help="Set the initial value of lambda")
    csc_args.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
                          help="Learning rate for the safety critic")
    csc_args.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
                          help="Learning rate for the value network")
    csc_args.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N",
                          help="Learning rate for the lambda dual variable")
    csc_args.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
                          help="Hidden dimension of the networks")
    csc_args.add_argument("--sigmoid_activation", action="store_true", default=False,
                          help="Apply sigmoid activation to the safety critic's output")
    csc_args.add_argument("--shielded_action_sampling", action="store_true", default=False,
                          help="Sample shielded actions when performing parameter updates")

    # common args
    common_args = parser.add_argument_group('Common')
    common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N",
                             help="Set a custom seed for the rng")
    common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
                             help="Set the device for pytorch to use")
    common_args.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
                             help="Set the output and log directory path")
    common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
                             help="Set the maximum number of threads for pytorch and numpy")

    args = parser.parse_args()
    return args
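Because the parser now uses ArgumentDefaultsHelpFormatter, argparse appends each argument's default to its help text automatically, which is presumably why the hard-coded "(default: ...)" suffixes were dropped from the individual help strings. A small standalone check of that behaviour (not part of the commit):

import argparse

# standalone demonstration; the training script builds its parser inside cmd_args()
parser = argparse.ArgumentParser(
    formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))
common = parser.add_argument_group('Common')
common.add_argument("--seed", action="store", type=int, default=42, metavar="N",
                    help="Set a custom seed for the rng")
parser.print_help()  # the help line renders as "Set a custom seed for the rng (default: 42)"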
##################
#     SETUP
##################
def setup(args):
    """
    Performs setup like fixing seeds, initializing env and agent, buffer and stats.
    """
    torch.set_num_threads(args.num_threads)
    os.environ["OMP_NUM_THREADS"] = str(args.num_threads)         # export OMP_NUM_THREADS=args.num_threads
    os.environ["OPENBLAS_NUM_THREADS"] = str(args.num_threads)    # export OPENBLAS_NUM_THREADS=args.num_threads
    os.environ["MKL_NUM_THREADS"] = str(args.num_threads)         # export MKL_NUM_THREADS=args.num_threads
    os.environ["VECLIB_MAXIMUM_THREADS"] = str(args.num_threads)  # export VECLIB_MAXIMUM_THREADS=args.num_threads
    os.environ["NUMEXPR_NUM_THREADS"] = str(args.num_threads)     # export NUMEXPR_NUM_THREADS=args.num_threads

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_default_dtype(torch.float64)

    output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
    writer = SummaryWriter(log_dir=output_dir)
    with open(os.path.join(output_dir, "config.json"), "w") as file:
        json.dump(args.__dict__, file, indent=2)

    env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
    buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
    stats = Statistics(writer)
    agent = CSCAgent(env, args, buffer, stats)
    return env, agent, buffer, stats
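src/buffer.py is not touched by this commit, but the reworked main now passes six fields to buffer.add (the done flags are new) and calls buffer.clear(). The following is only a hypothetical sketch of the interface this script appears to assume; the actual ReplayBuffer in src.buffer may look different:

import numpy as np

class ReplayBufferSketch:
    """Hypothetical stand-in for src.buffer.ReplayBuffer, illustrating the calls made from main."""

    def __init__(self, env, cap):
        self.cap = cap
        self.storage = []  # one entry per stored transition

    def add(self, states, actions, rewards, costs, next_states, dones):
        # the vectorized exploration passes batched arrays: store one transition per environment
        for transition in zip(states, actions, rewards, costs, next_states, dones):
            self.storage.append(transition)
            if len(self.storage) > self.cap:
                self.storage.pop(0)

    def sample(self, batch_size):
        # not called in main directly; presumably used by the agent during updates
        idx = np.random.randint(0, len(self.storage), size=batch_size)
        batch = [self.storage[i] for i in idx]
        return tuple(np.stack(column) for column in zip(*batch))

    def clear(self):
        self.storage.clear()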
##################
#  EXPLORATION
##################
@torch.no_grad
def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
    # track currently running and leftover episodes
    open_episodes = args.train_episodes if train else args.test_episodes
    running_episodes = args.num_vectorized_envs

    # initialize mask and stats per episode
    mask = np.ones(args.num_vectorized_envs, dtype=np.bool_)
    episode_steps = np.zeros_like(mask, dtype=np.uint64)
    episode_reward = np.zeros_like(mask, dtype=np.float64)
    episode_cost = np.zeros_like(mask, dtype=np.float64)

    # adjust mask in case we have fewer runs than environments
    if open_episodes < args.num_vectorized_envs:
        mask[open_episodes:] = False
        running_episodes = open_episodes
    open_episodes -= running_episodes

    state, info = env.reset(seed=random.randint(0, 2**31 - 1))

    while running_episodes > 0:
        # sample and execute actions
        with torch.no_grad():
            actions = agent.sample(state, shielded=shielded).cpu().numpy()
        next_state, reward, cost, terminated, truncated, info = env.step(actions)
        done = terminated | truncated
        not_done_masked = (~done) & mask
        done_masked = done & mask
        done_masked_count = done_masked.sum()

        # increment stats
        episode_steps[mask] += 1
        episode_reward[mask] += reward[mask]
        episode_cost[mask] += cost[mask]

        # if any run has finished, we need to take special care:
        # 1. train: extract final_observation from the info dict (single envs autoreset, no manual reset needed) and add it to the buffer
        # 2. train+test: log episode stats using the Statistics class
        # 3. train+test: reset episode stats
        # 4. train+test: adjust mask (if necessary)
        if done_masked_count > 0:
            if train:
                # add experiences to buffer
                buffer.add(
                    state[not_done_masked],
                    actions[not_done_masked],
                    reward[not_done_masked],
                    cost[not_done_masked],
                    next_state[not_done_masked],
                    done[not_done_masked]
                )
                buffer.add(
                    state[done_masked],
                    actions[done_masked],
                    reward[done_masked],
                    cost[done_masked],
                    np.stack(info['final_observation'], axis=0)[done_masked],
                    done[done_masked]
                )
                # record finished episodes
                stats.record_train(
                    num_episodes=done_masked_count,
                    returns=episode_reward[done_masked],
                    costs=episode_cost[done_masked],
                    steps=episode_steps[done_masked],
                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
                )
                stats.total_train_episodes += done_masked_count
                stats.total_train_steps += episode_steps[done_masked].sum()
                stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
            else:
                # record finished episodes
                # the stats module performs averaging over all episodes
                stats.record_test(
                    shielded=shielded,
                    avg_returns=episode_reward[done_masked],
                    avg_costs=episode_cost[done_masked],
                    avg_steps=episode_steps[done_masked],
                    avg_unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
                )

            # reset episode stats
            state = next_state
            episode_steps[done_masked] = 0
            episode_reward[done_masked] = 0
            episode_cost[done_masked] = 0

            # adjust mask, running and open episode counters
            if open_episodes < done_masked_count:  # fewer episodes left than just finished
                done_masked_idxs = done_masked.nonzero()[0]
                mask[done_masked_idxs[open_episodes:]] = False
                running_episodes -= (done_masked_count - open_episodes)
                open_episodes = 0
            else:  # at least as many episodes left as just finished
                open_episodes -= done_masked_count

        # no run has finished, just record experiences (if training)
        else:
            if train:
                buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask])
            # we don't care about the cost limit while testing
            state = next_state

    # after exploration, flush stats
    stats.after_exploration(train, shielded)
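src/stats.py is also outside this diff. Judging from the calls above, Statistics owns the SummaryWriter created in setup, buffers per-episode results via record_train / record_test, keeps running totals, and writes everything out in after_exploration. A hypothetical minimal stand-in, for illustration only; the real class may differ:

import numpy as np
from torch.utils.tensorboard import SummaryWriter

class StatisticsSketch:
    """Hypothetical stand-in for src.stats.Statistics; names mirror how the class is used in main."""

    def __init__(self, writer: SummaryWriter):
        self.writer = writer
        self.total_train_episodes = 0
        self.total_train_steps = 0
        self.total_train_unsafe = 0
        self._episodes = []  # buffered (return, cost, steps, unsafe) tuples since the last flush

    def record_train(self, num_episodes, returns, costs, steps, unsafe):
        self._episodes.extend(zip(returns, costs, steps, unsafe))

    def record_test(self, shielded, avg_returns, avg_costs, avg_steps, avg_unsafe):
        self._episodes.extend(zip(avg_returns, avg_costs, avg_steps, avg_unsafe))

    def after_exploration(self, train, shielded):
        # average the buffered episodes, push them to tensorboard and reset the buffer
        if not self._episodes:
            return
        avg_return, avg_cost, avg_steps, avg_unsafe = map(np.mean, zip(*self._episodes))
        prefix = "train" if train else ("test_shielded" if shielded else "test_unshielded")
        self.writer.add_scalar(f"{prefix}/avg_reward", avg_return, self.total_train_steps)
        self.writer.add_scalar(f"{prefix}/avg_cost", avg_cost, self.total_train_steps)
        self.writer.add_scalar(f"{prefix}/avg_steps", avg_steps, self.total_train_steps)
        self.writer.add_scalar(f"{prefix}/avg_unsafe", avg_unsafe, self.total_train_steps)
        self._episodes.clear()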
##################
#   MAIN LOOP
##################
def main(args, env, agent, buffer, stats):
    finished = False
    while not finished:

        # Training + Update Loop
        for _ in range(args.train_until_test):
            # stop once the configured number of training steps has been collected
            if stats.total_train_steps >= args.total_train_steps:
                finished = True
                break

            # 1. Run exploration for training
            run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)

            # 2. Perform updates
            for _ in range(args.update_iterations):
                agent.update()

            # 3. After update stuff
            agent.after_updates()
            if args.clear_buffer:
                buffer.clear()

        # Test loop (shielded and unshielded)
        for shielded in [True, False]:
            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
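The agent side of these calls lives in src/policy.py, which this commit does not change. The main loop only relies on sample, update and after_updates; the stub below is a hypothetical illustration of that contract and acts randomly instead of implementing the conservative safety critic:

import torch

class CSCAgentStub:
    """Hypothetical stub mirroring the CSCAgent calls used above; it acts randomly and learns nothing."""

    def __init__(self, env, args, buffer, stats):
        self.env = env
        self.args = args
        self.buffer = buffer
        self.stats = stats

    def sample(self, state, shielded=True):
        # main expects a batch of actions as a tensor, which it converts via .cpu().numpy()
        return torch.as_tensor(self.env.action_space.sample())

    def update(self):
        # one optimization step, presumably drawing a batch of size args.batch_size from the buffer
        pass

    def after_updates(self):
        # e.g. soft target-network updates with args.tau; a no-op in this stub
        pass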
if __name__ == '__main__':
    args = cmd_args()
    main(args, *setup(args))