diff --git a/main.py b/main.py
index 7eef356d171125bab10280363f2ecdecc547f607..d9d5b7c0eb05053b4bbe887549ca9274f02fec99 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from src.stats import Statistics
 from src.buffer import ReplayBuffer
-from src.policy import CSCAgent
+from src.cql_sac.agent import CQLSAC
 
 ##################
 # ARGPARSER
@@ -28,6 +28,8 @@ def cmd_args():
 
     # train and test args
     train_test_args = parser.add_argument_group('Train and Test')
+    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
+                                 help="Total number of steps until training is finished")
     train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
                                  help="Number of episodes until policy optimization")
     train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
@@ -36,12 +38,17 @@ def cmd_args():
                                  help="Number of updates performed after each training step")
     train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
                                  help="Number of episodes used for testing")
-    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
-                                 help="Total number of steps until training is finished")
-    train_test_args.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
+
+    # update args
+    update_args = parser.add_argument_group('Update')
+    update_args.add_argument("--batch_size", action="store", type=int, default=256, metavar="N",
                                  help="Batch size used for training")
-    train_test_args.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
+    update_args.add_argument("--tau", action="store", type=float, default=5e-3, metavar="N",
                                  help="Factor used in soft update of target network")
+    update_args.add_argument("--gamma", action="store", type=float, default=0.95, metavar="N",
+                             help="Discount factor for rewards")
+    update_args.add_argument("--learning_rate", action="store", type=float, default=3e-4, metavar="N",
+                             help="Learn rate for the policy and Q networks")
 
     # buffer args
     buffer_args = parser.add_argument_group('Buffer')
@@ -49,39 +56,35 @@ def cmd_args():
                              help="Define the maximum capacity of the replay buffer")
     buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
                              help="Clear Replay Buffer after every optimization step")
-
+
+    # network args
+    network_args = parser.add_argument_group('Networks')
+    network_args.add_argument("--hidden_size", action="store", type=int, default=256, metavar="N",
+                              help="Hidden size of the networks")
+
+    # cql args
+    cql_args = parser.add_argument_group('CQL')
+    cql_args.add_argument("--cql_with_lagrange", action="store", type=float, default=0, metavar="N",
+                          help="Use the Lagrange variant of CQL, which tunes the conservative penalty weight automatically")
+    cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
+                          help="Temperature used in the CQL logsumexp term")
+    cql_args.add_argument("--cql_weight", action="store", type=float, default=1.0, metavar="N",
+                          help="Weight of the conservative (CQL) regularization term")
+    cql_args.add_argument("--cql_target_action_gap", action="store", type=float, default=10, metavar="N",
+                          help="Target action gap used by the Lagrange variant")
+
     # csc args
-    csc_args = parser.add_argument_group('Agent')
-    csc_args.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
-                          help="Maximum number of actions sampled during shielding")
-    csc_args.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
-                          help="Maximum number of line search update iterations")
- csc_args.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N", - help="Number of samples to estimate expectations") + csc_args = parser.add_argument_group('CSC') csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N", help="Set the value of chi") csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N", help="Set the value of delta") - csc_args.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N", - help="Set the value of gamma") csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N", help="Set the value of beta") csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N", help="Set the value of alpha") csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N", help="Set the initial value of lambda") - csc_args.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N", - help="Learn rate for the safety critic") - csc_args.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N", - help="Learn rate for the value network") - csc_args.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N", - help="Learn rate for the lambda dual variable") - csc_args.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N", - help="Hidden dimension of the networks") - csc_args.add_argument("--sigmoid_activation", action="store_true", default=False, - help="Apply sigmoid activation to the safety critics output") - csc_args.add_argument("--shielded_action_sampling", action="store_true", default=False, - help="Sample shielded actions when performing parameter updates") # common args common_args = parser.add_argument_group('Common') @@ -114,10 +117,10 @@ def setup(args): with open(os.path.join(output_dir, "config.json"), "w") as file: json.dump(args.__dict__, file, indent=2) - env = env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False) + env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False) buffer = ReplayBuffer(env=env, cap=args.buffer_capacity) stats = Statistics(writer) - agent = CSCAgent(env, args, buffer, stats) + agent = CQLSAC(env=env, args=args, stats=stats) return env, agent, buffer, stats @@ -147,8 +150,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie while running_episodes > 0: # sample and execute actions - with torch.no_grad(): - actions = agent.sample(state, shielded=shielded).cpu().numpy() + actions = agent.get_action(state, eval=not train).cpu().numpy() next_state, reward, cost, terminated, truncated, info = env.step(actions) done = terminated | truncated not_done_masked = ((~done) & mask) @@ -174,7 +176,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie reward[not_done_masked], cost[not_done_masked], next_state[not_done_masked], - done[not_done_masked] + terminated[not_done_masked] ) buffer.add( state[done_masked], @@ -182,7 +184,7 @@ def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shie reward[done_masked], cost[done_masked], np.stack(info['final_observation'], axis=0)[done_masked], - done[done_masked] + terminated[done_masked] ) # record finished episodes @@ -247,10 +249,9 @@ def main(args, env, agent, buffer, stats): # 2. 
         for _ in range(args.update_iterations):
-            agent.update()
+            agent.learn(experiences=buffer.sample(n=args.batch_size))
 
         # 3. After update stuff
-        agent.after_updates()
         if args.clear_buffer:
             buffer.clear()
 