diff --git a/main.py b/main.py
index 4766320bf70fffd9b9ca931db9e82f4b4cab987d..87d53e6f93e62c4f34be8391a1b61426391307ba 100644
--- a/main.py
+++ b/main.py
@@ -23,20 +23,20 @@ def cmd_args():
                           help="Set the environment")
     env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
                           help="Set a cost limit/budget")
-    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
+    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=8, metavar="N",
                           help="Sets the number of vectorized environments")
 
     # train and test args
     train_test_args = parser.add_argument_group('Train and Test')
     train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
                                  help="Total number of steps until training is finished")
-    train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
+    train_test_args.add_argument("--train_episodes", action="store", type=int, default=8, metavar="N",
                                  help="Number of episodes until policy optimization")
     train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
                                  help="Perform evaluation after N * total_train_episodes episodes of training")
-    train_test_args.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
+    train_test_args.add_argument("--update_iterations", action="store", type=int, default=128, metavar="N",
                                  help="Number of updates performed after each training step")
-    train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
+    train_test_args.add_argument("--test_episodes", action="store", type=int, default=16, metavar="N",
                                  help="Number of episodes used for testing")
 
     # update args
@@ -52,7 +52,7 @@ def cmd_args():
 
     # buffer args
     buffer_args = parser.add_argument_group('Buffer')
-    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
+    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=100_000, metavar="N",
                              help="Define the maximum capacity of the replay buffer")
     buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
                              help="Clear Replay Buffer after every optimization step")
@@ -65,13 +65,13 @@ def cmd_args():
     # cql args
     cql_args = parser.add_argument_group('CQL')
     cql_args.add_argument("--cql_with_lagrange", action="store_true", default=False,
-                          help="")
+                          help="Enable automatic alpha tuning via the Lagrange variant of CQL")
     cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
-                          help="")
+                          help="Set the CQL temperature")
     cql_args.add_argument("--cql_weight", action="store", type=float, default=1.0, metavar="N",
-                          help="")
+                          help="Set the weight of the CQL loss term")
     cql_args.add_argument("--cql_target_action_gap", action="store", type=float, default=10, metavar="N",
-                          help="")
+                          help="Set the target action gap for the CQL Lagrange variant")
 
     # csc args
     csc_args = parser.add_argument_group('CSC')
@@ -295,6 +295,11 @@ def main(args, env, agent, buffer, stats):
 
     # Training + Update Loop
     for _ in range(args.train_until_test):
+        # Check if training is finished
+        if stats.total_train_steps >= args.total_train_steps:
+            finished = True
+            break
+
         # 1. Run exploration for training
         run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)