From be5deebb735a55bcfaa5757792d1fb1abc4a19a4 Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Thu, 6 Mar 2025 22:22:17 +0100 Subject: [PATCH] Reverted vector env to single env --- main.py | 201 +++++++++++++++++++------------------------ src/buffer.py | 50 ++++------- src/cql_sac/agent.py | 51 +++++------ src/environment.py | 93 +++++++++++++------- src/stats.py | 55 ++++-------- 5 files changed, 206 insertions(+), 244 deletions(-) diff --git a/main.py b/main.py index bf05dfb..23f6cb2 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from src.stats import Statistics from src.buffer import ReplayBuffer from src.cql_sac.agent import CSCCQLSAC -from src.environment import create_environment +from src.environment import create_env ################## # ARGPARSER @@ -21,20 +21,20 @@ def cmd_args(): env_args = parser.add_argument_group('Environment') env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID", help="Set the environment") - env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N", + env_args.add_argument('--env_args', type=str, default='', + help='Environment specific arguments') + env_args.add_argument("--cost_limit", action="store", type=float, default=0.0, metavar="N", help="Set a cost limit/budget") - env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=8, metavar="N", - help="Sets the number of vectorized environments") # train and test args train_test_args = parser.add_argument_group('Train and Test') - train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N", + train_test_args.add_argument("--total_train_steps", action="store", type=int, default=5_000_000, metavar="N", help="Total number of steps until training is finished") - train_test_args.add_argument("--train_episodes", action="store", type=int, default=8, metavar="N", + train_test_args.add_argument("--train_episodes", action="store", type=int, default=4, metavar="N", help="Number of episodes until policy optimization") - train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N", + train_test_args.add_argument("--train_until_test", action="store", type=int, default=4, metavar="N", help="Perform evaluation after N * total_train_episodes episodes of training") - train_test_args.add_argument("--update_iterations", action="store", type=int, default=128, metavar="N", + train_test_args.add_argument("--update_iterations", action="store", type=int, default=64, metavar="N", help="Number of updates performed after each training step") train_test_args.add_argument("--test_episodes", action="store", type=int, default=16, metavar="N", help="Number of episodes used for testing") @@ -43,16 +43,16 @@ def cmd_args(): update_args = parser.add_argument_group('Update') update_args.add_argument("--batch_size", action="store", type=int, default=256, metavar="N", help="Batch size used for training") - update_args.add_argument("--tau", action="store", type=float, default=5e-3, metavar="N", + update_args.add_argument("--tau", action="store", type=float, default=0.005, metavar="N", help="Factor used in soft update of target network") update_args.add_argument("--gamma", action="store", type=float, default=0.95, metavar="N", help="Discount factor for rewards") - update_args.add_argument("--learning_rate", action="store", type=float, default=3e-4, metavar="N", + 
update_args.add_argument("--learning_rate", action="store", type=float, default=0.0003, metavar="N", help="Learn rate for the policy and Q networks") # buffer args buffer_args = parser.add_argument_group('Buffer') - buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=100_000, metavar="N", + buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=250_000, metavar="N", help="Define the maximum capacity of the replay buffer") buffer_args.add_argument("--clear_buffer", action="store_true", default=False, help="Clear Replay Buffer after every optimization step") @@ -88,7 +88,7 @@ def cmd_args(): # common args common_args = parser.add_argument_group('Common') - common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N", + common_args.add_argument("--seed", action="store", type=int, default=0, metavar="N", help="Set a custom seed for the rng") common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE", help="Set the device for pytorch to use") @@ -118,7 +118,7 @@ def setup(args): with open(os.path.join(output_dir, "config.json"), "w") as file: json.dump(args.__dict__, file, indent=2) - env = create_environment(args=args) + env = create_env(args=args) buffer = ReplayBuffer(env=env, cap=args.buffer_capacity) stats = Statistics(writer=writer) agent = CSCCQLSAC(env=env, args=args, stats=stats) @@ -130,111 +130,85 @@ def setup(args): ################## @torch.no_grad -def run_vectorized_exploration(args, env, agent:CSCCQLSAC, buffer:ReplayBuffer, stats:Statistics, train=True, shielded=True): - # track currently running and leftover episodes - open_episodes = args.train_episodes if train else args.test_episodes - running_episodes = args.num_vectorized_envs - - # initialize mask and stats per episode - mask = np.ones(args.num_vectorized_envs, dtype=np.bool_) - episode_steps = np.zeros_like(mask, dtype=np.uint64) - episode_reward = np.zeros_like(mask, dtype=np.float64) - episode_cost = np.zeros_like(mask, dtype=np.float64) - - # adjust mask in case we have fewer runs than environments - if open_episodes < args.num_vectorized_envs: - mask[open_episodes:] = False - running_episodes = open_episodes - open_episodes -= running_episodes - - state, info = env.reset(seed=random.randint(0, 2**31-1)) +def run_single_exploration(args, env, agent:CSCCQLSAC, buffer:ReplayBuffer, stats:Statistics, train=True, shielded=True): + # track leftover episodes + num_episodes = args.train_episodes if train else args.test_episodes + + avg_steps = 0. + avg_return = 0. + avg_cost = 0. + avg_unsafe = 0. - while running_episodes > 0: - # sample and execute actions - actions = agent.get_action(state, shielded=shielded).cpu().numpy() - next_state, reward, cost, terminated, truncated, info = env.step(actions) - done = terminated | truncated - not_done_masked = ((~done) & mask) - done_masked = done & mask - done_masked_count = done_masked.sum() - - # increment stats - episode_steps[mask] += 1 - episode_reward[mask] += reward[mask] - episode_cost[mask] += cost[mask] - - # if any run has finished, we need to take special care - # 1. train: extract final_observation from info dict (single envs autoreset, no manual reset needed) and add to buffer - # 2. train+test: log episode stats using Statistics class - # 3. train+test: reset episode stats - # 4. train+test: adjust mask (if necessary) - if done_masked_count > 0: + for _ in range(num_episodes): + # initialize stats per episode + episode_steps = 0 + episode_reward = 0. 
+ episode_cost = 0. + done = False + + # reset env + state, info = env.reset(seed=random.randint(0, 2**31-1)) + + while not done: + # sample and execute actions + action = agent.get_action(state, shielded=shielded).cpu().numpy() + next_state, reward, cost, terminated, truncated, info = env.step(action) + done = terminated or truncated + + # increment stats + episode_steps += 1 + episode_reward += reward + episode_cost += cost + + # During training, add experiences to buffer if train: - # add experiences to buffer - buffer.add( - state[not_done_masked], - actions[not_done_masked], - reward[not_done_masked], - cost[not_done_masked], - next_state[not_done_masked], - terminated[not_done_masked] - ) buffer.add( - state[done_masked], - actions[done_masked], - reward[done_masked], - cost[done_masked], - np.stack(info['final_observation'], axis=0)[done_masked], - terminated[done_masked] + state, action, reward, cost, next_state, not truncated ) - - # record finished episodes - ticks = stats.total_train_steps + np.cumsum(episode_steps[done_masked], axis=0) - stats.log_tensorboard("train/returns", episode_reward[done_masked], ticks) - stats.log_tensorboard("train/costs", episode_cost[done_masked], ticks) - stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks) - stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks) - - stats.log_train_history((episode_cost[done_masked] > args.cost_limit).astype(np.uint8)) - - stats.total_train_episodes += done_masked_count - stats.total_train_steps += episode_steps[done_masked].sum() - stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum() - else: - # record finished episodes - # stats module performs averaging over all when logging episodes - stats.log_test_history( - steps=episode_steps[done_masked], - reward=episode_reward[done_masked], - cost=episode_cost[done_masked], - unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8) - ) + # End episode if the cost exceeded the limit + if episode_cost > args.cost_limit: + done = True - # reset episode stats - state = next_state - episode_steps[done_masked] = 0 - episode_reward[done_masked] = 0 - episode_cost[done_masked] = 0 - - # adjust mask, running and open episodes counter - if open_episodes < done_masked_count: # fewer left than just finished - done_masked_idxs = done_masked.nonzero()[0] - mask[done_masked_idxs[open_episodes:]] = False - running_episodes -= (done_masked_count - open_episodes) - open_episodes = 0 - else: # at least as many left than just finished - open_episodes -= done_masked_count - - # no run has finished, just record experiences (if training) - else: - if train: - buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask]) + # Update state state = next_state - # average and log finished test episodes - if not train: - stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps) + # Update statistics + count_unsafe = int(episode_cost > args.cost_limit) + stats.total_train_episodes += 1 + stats.total_train_steps += episode_steps + stats.total_train_unsafe += count_unsafe + + # record finished episodes + stats.log_train_tensorboard( + episode_steps=episode_steps, + episode_return=episode_reward, + episode_cost=episode_cost + ) + + avg_return += episode_reward + avg_steps += episode_steps + avg_cost += episode_cost + avg_unsafe += count_unsafe + + # average and log stats + avg_return /= num_episodes + avg_steps /= 
num_episodes + avg_cost /= num_episodes + avg_unsafe /= num_episodes + + # In training, update avg_unsafe + # In testing, log averages + if train: + stats.train_unsafe_avg = avg_unsafe + else: + stats.log_test_tensorboard( + avg_steps=avg_steps, + avg_return=avg_return, + avg_cost=avg_cost, + avg_unsafe=avg_unsafe, + shielded=shielded + ) ################## # MAIN LOOP @@ -252,11 +226,12 @@ def main(args, env, agent, buffer, stats): break # 1. Run exploration for training - run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True) + run_single_exploration(args, env, agent, buffer, stats, train=True, shielded=True) # 2. Perform updates for _ in range(args.update_iterations): - agent.learn(experiences=buffer.sample(n=args.batch_size)) + experiences=buffer.sample(n=args.batch_size) + agent.learn(experiences=experiences) # 3. After update stuff if args.clear_buffer: @@ -264,7 +239,7 @@ def main(args, env, agent, buffer, stats): # Test loop (shielded and unshielded) for shielded in [True, False]: - run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded) + run_single_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded) if __name__ == '__main__': args = cmd_args() diff --git a/src/buffer.py b/src/buffer.py index 6b72adb..c1c839a 100644 --- a/src/buffer.py +++ b/src/buffer.py @@ -2,54 +2,37 @@ import numpy as np class ReplayBuffer(): """ - Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. Handles batched experiences. + Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. """ def __init__(self, env, cap): self._cap = max(1,cap) self._size = 0 # number of experiences in the buffer self._ptr = 0 # pointer to the next available slot in the buffer - self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64) - self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64) + self._states = np.zeros((cap, env.observation_space.shape[0]), dtype=np.float64) + self._actions = np.zeros((cap, env.action_space.shape[0]), dtype=np.float64) self._rewards = np.zeros((cap, ), dtype=np.float64) self._costs = np.zeros((cap, ), dtype=np.float64) self._next_states = np.zeros_like(self._states) self._dones = np.zeros((cap, ), dtype=np.uint8) - - def _add(self, state, action, reward, cost, next_state, done, start, end): - self._states[start:end] = state - self._actions[start:end] = action - self._rewards[start:end] = reward - self._costs[start:end] = cost - self._next_states[start:end] = next_state - self._dones[start:end] = done - + def _add(self, state, action, reward, cost, next_state, done): + idx = self._ptr + self._states[idx] = state + self._actions[idx] = action + self._rewards[idx] = reward + self._costs[idx] = cost + self._next_states[idx] = next_state + self._dones[idx] = done def add(self, state, action, reward, cost, next_state, done): """ - Adds experiences to the buffer. Assumes batched experiences. + Adds an experience to the buffer. 
""" - n = state.shape[0] # NOTE: n should be less than or equal to the buffer capacity - idx_start = self._ptr - idx_end = self._ptr + n - - # if the buffer has capacity, add the experiences to the end of the buffer - if idx_end <= self._cap: - self._add(state, action, reward, cost, next_state, done, idx_start, idx_end) - - # if the buffer does not have capacity, add the experiences to the end of the buffer and wrap around - else: - k = self._cap - idx_start - idx_end = n - k - self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap) - self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end) - - # update the buffer size and pointer - self._ptr = idx_end - if self._size < self._cap: - self._size = min(self._cap, self._size + n) - + self._add(state, action, reward, cost, next_state, done) + self._ptr += 1 + self._size = max(self._size, self._ptr) + self._ptr = self._ptr % self._cap def sample(self, n): """ @@ -59,7 +42,6 @@ class ReplayBuffer(): return self._states[idxs], self._actions[idxs], self._rewards[idxs], \ self._costs[idxs], self._next_states[idxs], self._dones[idxs] - def clear(self): """ Clears the buffer. diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py index 6f10681..8ed7054 100644 --- a/src/cql_sac/agent.py +++ b/src/cql_sac/agent.py @@ -4,9 +4,7 @@ import torch.nn.functional as F import torch.nn as nn from torch.nn.utils import clip_grad_norm_ from .networks import Critic, Actor -import numpy as np import math -import copy class CSCCQLSAC(nn.Module): @@ -27,8 +25,8 @@ class CSCCQLSAC(nn.Module): super(CSCCQLSAC, self).__init__() self.stats = stats - state_size = env.observation_space.shape[-1] - action_size = env.action_space.shape[-1] + state_size = env.observation_space.shape[0] + action_size = env.action_space.shape[0] hidden_size = args.hidden_size self.action_size = action_size @@ -62,7 +60,7 @@ class CSCCQLSAC(nn.Module): self.csc_avg_unsafe = args.csc_chi self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device) - self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) + self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) # Actor Network self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device) @@ -79,7 +77,7 @@ class CSCCQLSAC(nn.Module): self.safety_critic2_target.load_state_dict(self.safety_critic2.state_dict()) self.safety_critic1_optimizer = optim.Adam(self.safety_critic1.parameters(), lr=self.learning_rate) - self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate) + self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate) # Critic Network (w/ Target Network) self.critic1 = Critic(state_size, action_size, hidden_size).to(self.device) @@ -92,46 +90,43 @@ class CSCCQLSAC(nn.Module): self.critic2_target.load_state_dict(self.critic2.state_dict()) self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate) - self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate) + self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate) def get_action(self, state, shielded=True): """ Returns shielded actions for given state as per current policy. 
""" - state = torch.from_numpy(state).float().to(self.device) + state = torch.from_numpy(state).float().to(self.device).unsqueeze(0) if shielded: - # Repeat state, resulting shape: (shield_iterations, batch_size, state_size) - state = state.repeat((self.csc_shield_iterations, 1)).reshape(self.csc_shield_iterations, *state.shape) + # Repeat state, resulting shape: (shield_iterations, state_size) + state = state.repeat((self.csc_shield_iterations, 1)) unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe) - # Sample all 'csc_shield_iterations' actions at once for every state + # Sample all 'csc_shield_iterations' actions at once with torch.no_grad(): action = self.actor_local.get_action(state).to(self.device) # Estimate unsafety of all actions q1 = self.safety_critic1(state, action) q2 = self.safety_critic2(state, action) - unsafety = torch.min(q1, q2).squeeze(2) + unsafety = torch.min(q1, q2).squeeze(1) - # Check for actions that qualify (unsafety <= threshold), locate first (if exists) for every state + # Check for actions that qualify (unsafety <= threshold), locate first (if exists) mask_safe = unsafety <= unsafety_threshold - idx_first_safe = mask_safe.int().argmax(dim=0) - # If all actions are unsafe (> threshold), argmax will return the first unsafe as mask_safe[:,,] == 0 everywhere - mask_all_unsafe = (~mask_safe[idx_first_safe, torch.arange(0, mask_safe.shape[1])]) + # Search for first safe action, check if one exists + idx_first_safe = mask_safe.int().argmax(dim=0).item() + is_safe = mask_safe[idx_first_safe].item() - # We now build an idx to access the action tensor as follows: - # If there was at least one safe action, idx_first_safe will be the first safe action for each state - # If there was no safe action, we retrieve the action with minimum unsafety for each state - idx_0 = idx_first_safe - idx_0[mask_all_unsafe] = unsafety[:, mask_all_unsafe].argmin(dim=0) - idx_1 = torch.arange(0, mask_safe.shape[1]) - - # Access action tensor and return - return action[idx_0, idx_1, :] + # Return first safe or alternatively the safest one + if is_safe: + return action[idx_first_safe, :] + else: + idx_best = unsafety.argmin(dim=0).item() + return action[idx_best, :] else: - return self.actor_local.get_action(state).to(self.device) + return self.actor_local.get_action(state).squeeze(0).to(self.device) def calc_policy_loss(self, states, alpha): actions_pred, log_pis = self.actor_local.evaluate(states) @@ -348,9 +343,9 @@ class CSCCQLSAC(nn.Module): "cql_alpha_loss": cql_alpha_loss.item(), "cql_alpha": cql_alpha.item() } - if self.stats.total_updates % 8 == 0: - self.stats.log_update_tensorboard(data) self.stats.total_updates += 1 + if (self.stats.total_updates - 1) % 8 == 0: + self.stats.log_update_tensorboard(data) return data def soft_update(self, local_model , target_model): diff --git a/src/environment.py b/src/environment.py index dfb8971..87d5411 100644 --- a/src/environment.py +++ b/src/environment.py @@ -1,29 +1,64 @@ -import safety_gymnasium -import gymnasium -import numpy as np - -class Gymnasium2SafetyGymnasium(gymnasium.Wrapper): - def step(self, action): - state, reward, terminated, truncated, info = super().step(action) - if 'cost' in info: - cost = info['cost'] - elif isinstance(reward, (int, float)): - cost = 0 - elif isinstance(reward, np.ndarray): - cost = np.zeros_like(reward) - elif isinstance(reward, list): - cost = [0]*len(reward) - else: - raise NotImplementedError("reward type not recognized") # for now - return state, reward, cost, terminated, 
truncated, info
-
-def create_environment(args):
-    if args.env_id.startswith("Safety"):
-        env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
-    if args.env_id.startswith("Gymnasium_"):
-        id = args.env_id[len('Gymnasium_'):]
-        env = gymnasium.make_vec(id, num_envs=args.num_vectorized_envs, vectorization_mode="sync")
-        env = Gymnasium2SafetyGymnasium(env)
-    if args.env_id.startswith("RaceTrack"):
-        raise NotImplementedError("RaceTrack environment is not implemented yet.")
-    return env
\ No newline at end of file
+def _parse_env_args_as_dict(env_args):
+    """
+    We assume the env args are given as a Python dict literal string, e.g.
+    "{'name1': value1, 'name2': value2, ...}"
+    """
+    import ast
+    if not env_args: env_args = '{}'
+    d = ast.literal_eval(env_args)
+    if isinstance(d, dict):
+        return d
+    raise ValueError("Expected a dictionary string as env_args!")
+
+
+def _create_env_racetrack(args):
+    import safety_gymnasium
+    from racetrackgym.argument_parser import Racetrack_parser
+    from racetrackgym.wrapper import SafetyGymnasiumRaceTrackEnv
+
+    # NOTE: racetrackgym has a custom parser, use it instead
+    rt_args = Racetrack_parser().parse(args.env_args.split())
+    args.env_args = rt_args.__dict__
+
+    safety_gymnasium.register(id=args.env_id, entry_point=SafetyGymnasiumRaceTrackEnv)
+    env = safety_gymnasium.make(id=args.env_id, rt_args=rt_args)
+    return env
+
+
+def _create_env_safetygym(args):
+    import safety_gymnasium
+    from safety_gymnasium.bases.base_task import LidarConf
+
+    # NOTE: we don't have a safety gymnasium argument parser, hence expect arguments as a dictionary string
+    sg_args = _parse_env_args_as_dict(args.env_args)
+    args.env_args = sg_args.copy()
+
+    # allow changing the lidar via args
+    if 'lidar' in sg_args:
+        lidar_config = dict()
+        for key, val in sg_args['lidar'].items():
+            lidar_config[key] = val
+
+        # Wrap init function with new default values
+        def new_init(self, **kwargs):
+            for key, val in lidar_config.items():
+                if key not in kwargs:
+                    kwargs[key] = val
+            return self.__init_original__(**kwargs)
+
+        LidarConf.__init_original__ = LidarConf.__init__
+        LidarConf.__init__ = new_init
+        del sg_args['lidar']
+
+    env = safety_gymnasium.make(id=args.env_id)
+    return env
+
+
+def create_env(args):
+    name = args.env_id
+    if name == "RaceTrack":
+        return _create_env_racetrack(args)
+    elif name.startswith("Safety"):
+        return _create_env_safetygym(args)
+    else:
+        raise RuntimeError("Unknown environment: " + name)
diff --git a/src/stats.py b/src/stats.py
index 4f263bf..af63fd1 100644
--- a/src/stats.py
+++ b/src/stats.py
@@ -10,51 +10,26 @@ class Statistics:
         self.total_train_unsafe = 0
         self.total_updates = 0
 
-        # Used for calculating average unsafe of previous training episodes
-        self.train_unsafe_history = []
-        self._train_unsafe_avg = 0
-
-        # Used for averaging test results
-        self.test_steps_history = []
-        self.test_reward_history = []
-        self.test_cost_history = []
-        self.test_unsafe_history = []
-
-    @property
-    def train_unsafe_avg(self):
-        if len(self.train_unsafe_history) > 0:
-            self._train_unsafe_avg = sum(self.train_unsafe_history) / len(self.train_unsafe_history)
-            self.train_unsafe_history.clear()
-        return self._train_unsafe_avg
+        # Store avg unsafe episodes for CSC
+        self.train_unsafe_avg = 0
 
-    def log_train_history(self, unsafe: np.ndarray):
-        self.train_unsafe_history += unsafe.tolist()
+    def log_train_tensorboard(self, episode_steps, episode_return, episode_cost):
+        name_prefix = f"train/"
+        self.log_tensorboard(name_prefix + 
'steps', episode_steps, self.total_train_steps) + self.log_tensorboard(name_prefix + 'returns', episode_return, self.total_train_steps) + self.log_tensorboard(name_prefix + 'costs', episode_cost, self.total_train_steps) - def log_test_history(self, steps: np.ndarray, reward: np.ndarray, cost: np.ndarray, unsafe: np.ndarray): - self.test_steps_history += steps.tolist() - self.test_reward_history += reward.tolist() - self.test_cost_history += cost.tolist() - self.test_unsafe_history += unsafe.tolist() - - def log_test_tensorboard(self, name_prefix, ticks): - self.log_tensorboard(name_prefix + '/avg_steps', np.array(self.test_steps_history).mean(), ticks) - self.log_tensorboard(name_prefix + '/avg_returns', np.array(self.test_reward_history).mean(), ticks) - self.log_tensorboard(name_prefix + '/avg_costs', np.array(self.test_cost_history).mean(), ticks) - self.log_tensorboard(name_prefix + '/avg_unsafes', np.array(self.test_unsafe_history).mean(), ticks) - - self.test_steps_history.clear() - self.test_reward_history.clear() - self.test_cost_history.clear() - self.test_unsafe_history.clear() + def log_test_tensorboard(self, avg_steps, avg_return, avg_cost, avg_unsafe, shielded=True): + name_prefix = f"test/{'shielded' if shielded else 'unshielded'}/" + self.log_tensorboard(name_prefix + 'avg_steps', avg_steps, self.total_train_steps) + self.log_tensorboard(name_prefix + 'avg_return', avg_return, self.total_train_steps) + self.log_tensorboard(name_prefix + 'avg_cost', avg_cost, self.total_train_steps) + self.log_tensorboard(name_prefix + 'avg_unsafe', avg_unsafe, self.total_train_steps) def log_update_tensorboard(self, data:dict): for k, v in data.items(): name = f"update/{k}" self.writer.add_scalar(name, v, self.total_updates) - def log_tensorboard(self, name: str, values: np.ndarray, ticks: np.ndarray): - if values.size == 1: - self.writer.add_scalar(name, values.item(), ticks.item()) - else: - for v, t in zip(values, ticks): - self.writer.add_scalar(name, v, t) \ No newline at end of file + def log_tensorboard(self, name, value, step): + self.writer.add_scalar(name, value, step) \ No newline at end of file -- GitLab
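Note on the new --env_args flag (illustrative sketch, not part of the patch): _parse_env_args_as_dict feeds the string to ast.literal_eval, so it must be a valid Python dict literal with quoted keys. Below is a minimal standalone example of the accepted format; the helper name parse_env_args and the LidarConf field 'num_bins' are assumptions for illustration only:

    import ast

    def parse_env_args(env_args: str) -> dict:
        # Mirrors _parse_env_args_as_dict from src/environment.py above.
        if not env_args:
            env_args = '{}'
        d = ast.literal_eval(env_args)
        if isinstance(d, dict):
            return d
        raise ValueError("Expected a dictionary string as env_args!")

    # Keys must be quoted so that ast.literal_eval accepts the string.
    print(parse_env_args("{'lidar': {'num_bins': 32}}"))
    # -> {'lidar': {'num_bins': 32}}

On the command line this would look roughly like:

    python main.py --env_id SafetyPointGoal1-v0 --env_args "{'lidar': {'num_bins': 32}}"

For Safety* environments, _create_env_safetygym wraps LidarConf.__init__ so that any keys under 'lidar' become new defaults for every task constructed afterwards; values the task passes explicitly still win, since the wrapper only fills in missing kwargs.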