diff --git a/main.py b/main.py
index 96a8b0d8ced1a42f8cdbe6cd0e2293ecfe838e60..8215e83da8493c2b95087fe51331d78bf901b928 100644
--- a/main.py
+++ b/main.py
@@ -1,198 +1,315 @@
 import argparse
-import safety_gymnasium
 import numpy as np
 import torch
 import random
 import os
 import json
 import datetime
-
 from torch.utils.tensorboard import SummaryWriter
+
+from src.stats import Statistics
 from src.buffer import ReplayBuffer
-from src.policy import CSCAgent
+from src.cql_sac.agent import CSCCQLSAC
+from src.environment import create_environment
 
 ##################
 # ARGPARSER
 ##################
+def cmd_args():
+    parser = argparse.ArgumentParser(formatter_class=lambda prog:argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))
+    # environment args
+    env_args = parser.add_argument_group('Environment')
+    env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
+                          help="Set the environment")
+    env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
+                          help="Set a cost limit/budget")
+    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
+                          help="Sets the number of vectorized environments")
+
+    # train and test args
+    train_test_args = parser.add_argument_group('Train and Test')
+    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
+                                 help="Total number of steps until training is finished")
+    train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
+                                 help="Number of episodes until policy optimization")
+    train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
+                                 help="Perform evaluation after N * total_train_episodes episodes of training")
+    train_test_args.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
+                                 help="Number of updates performed after each training step")
+    train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
+                                 help="Number of episodes used for testing")
+
+    # update args
+    update_args = parser.add_argument_group('Update')
+    update_args.add_argument("--batch_size", action="store", type=int, default=256, metavar="N",
+                             help="Batch size used for training")
+    update_args.add_argument("--tau", action="store", type=float, default=5e-3, metavar="N",
+                             help="Factor used in soft update of target network")
+    update_args.add_argument("--gamma", action="store", type=float, default=0.95, metavar="N",
+                             help="Discount factor for rewards")
+    update_args.add_argument("--learning_rate", action="store", type=float, default=3e-4, metavar="N",
+                             help="Learn rate for the policy and Q networks")
+
+    # buffer args
+    buffer_args = parser.add_argument_group('Buffer')
+    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
+                             help="Define the maximum capacity of the replay buffer")
+    buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
+                             help="Clear Replay Buffer after every optimization step")
+
+    # network args
+    network_args = parser.add_argument_group('Networks')
+    network_args.add_argument("--hidden_size", action="store", type=int, default=256, metavar="N",
+                              help="Hidden size of the networks")
+
+    # cql args
+    cql_args = parser.add_argument_group('CQL')
+    cql_args.add_argument("--cql_with_lagrange", action="store_true", default=False,
+                          help="Use the Lagrange variant of CQL that tunes the CQL penalty weight automatically")
+    cql_args.add_argument("--cql_temp", action="store", type=float, default=1.0, metavar="N",
+                          help="Temperature used in the CQL logsumexp term")
+    cql_args.add_argument("--cql_weight", action="store", type=float, default=1.0, metavar="N",
+                          help="Weight of the CQL penalty term")
+    cql_args.add_argument("--cql_target_action_gap", action="store", type=float, default=10, metavar="N",
+                          help="Target action gap used when the CQL Lagrange variant is enabled")
+
+    # csc args
+    csc_args = parser.add_argument_group('CSC')
+    csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
+                          help="Set the value of chi")
+    csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
+                          help="Set the value of delta")
+    csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
+                          help="Set the value of beta")
+    csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
+                          help="Set the value of alpha")
+    csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
+                          help="Set the initial value of lambda")
-parser = argparse.ArgumentParser()
-# environment args
-parser.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
-                    help="Set the environment (default: SafetyPointGoal1-v0)")
-parser.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
-                    help="Set a cost limit at which point an episode is considered unsafe (default: 25)")
-parser.add_argument("--enforce_cost_limit", action="store_true", default=False,
-                    help="Aborts episode if cost limit is reached (default: False)")
-
-# train args
-parser.add_argument("--train_episodes", action="store", type=int, default=5, metavar="N",
-                    help="Number of episodes until policy optimization (default: 5)")
-parser.add_argument("--train_until_test", action="store", type=int, default=3, metavar="N",
-                    help="Perform evaluation after N * train_episodes episodes of training (default: 3)")
-parser.add_argument("--update_iterations", action="store", type=int, default=1, metavar="N",
-                    help="Number of updates performed after each training step (default: 1)")
-parser.add_argument("--test_episodes", action="store", type=int, default=5, metavar="N",
-                    help="Number of episodes used for testing (default: 5)")
-parser.add_argument("--total_steps", action="store", type=int, default=2_000_000, metavar="N",
-                    help="Total number of steps until training is finished (default: 2_000_000)")
-parser.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
-                    help="Batch size used for training (default: 1024)")
-parser.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
-                    help="Factor used in soft update of target network (default: 0.05)")
-
-# buffer args
-parser.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
-                    help="Define the maximum capacity of the replay buffer (default: 50_000)")
-parser.add_argument("--clear_buffer", action="store_true", default=False,
-                    help="Clear Replay Buffer after every optimization step (default: False)")
-
-# csc args
-parser.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
-                    help="Maximum number of actions sampled during shielding (default: 100)")
-parser.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
-                    help="Maximum number of line search update iterations (default: 20)")
-parser.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
-                    help="Number of samples to
estimate expectations (default: 20)") -parser.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N", - help="Set the value of chi (default: 0.05)") -parser.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N", - help="Set the value of delta (default: 0.01)") -parser.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N", - help="Set the value of gamma (default: 0.99)") -parser.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N", - help="Set the value of beta (default: 0.7)") -parser.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N", - help="Set the value of alpha (default: 0.5)") -parser.add_argument("--csc_lambda", action="store", type=float, default=4e-2, metavar="N", - help="Set the initial value of lambda (default: 4e-2)") -parser.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N", - help="Learn rate for the safety critic (default: 2e-4)") -parser.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N", - help="Learn rate for the value network (default: 1e-3)") -parser.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N", - help="Hidden dimension of the networks (default: 32)") - -# common args -parser.add_argument("--seed", action="store", type=int, default=42, metavar="N", - help="Set a custom seed for the rng (default: 42)") -parser.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE", - help="Set the device for pytorch to use (default: cuda)") -parser.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH", - help="Set the output and log directory path (default: ./runs)") -parser.add_argument("--num_threads", action="store", type=int, default=32, metavar="N", - help="Set the maximum number of threads for pytorch (default: 32)") -args = parser.parse_args() + # common args + common_args = parser.add_argument_group('Common') + common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N", + help="Set a custom seed for the rng") + common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE", + help="Set the device for pytorch to use") + common_args.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH", + help="Set the output and log directory path") + common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N", + help="Set the maximum number of threads for pytorch and numpy") + + args = parser.parse_args() + return args ################## # SETUP ################## -torch.set_num_threads(args.num_threads) - -random.seed(args.seed) -np.random.seed(args.seed) -torch.manual_seed(args.seed) -torch.set_default_dtype(torch.float64) -torch.set_default_device(args.device) +def setup(args): + """ + Performs setup like fixing seeds, initializing env and agent, buffer and stats. 
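+ Returns the environment, agent, replay buffer and statistics tracker, and writes the parsed config to config.json inside the run directory.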
+ """ + torch.set_num_threads(args.num_threads) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) -output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S")) -writer = SummaryWriter(log_dir=output_dir) -with open(os.path.join(output_dir, "config.json"), "w") as file: - json.dump(args.__dict__, file, indent=2) + output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S")) + writer = SummaryWriter(log_dir=output_dir) + with open(os.path.join(output_dir, "config.json"), "w") as file: + json.dump(args.__dict__, file, indent=2) -env = safety_gymnasium.make(id=args.env_id, autoreset=False) -buffer = ReplayBuffer(env=env, cap=args.buffer_capacity) -agent = CSCAgent(env, args, writer) + env = create_environment(args=args) + buffer = ReplayBuffer(env=env, cap=args.buffer_capacity) + stats = Statistics(writer=writer) + agent = CSCCQLSAC(env=env, args=args, stats=stats) -total_episodes = 0 -total_failures = 0 -total_steps = 0 + return env, agent, buffer, stats ################## -# USEFUL FUNCTIONS +# EXPLORATION ################## @torch.no_grad -def run_episode_single(train, shielded): - episode_steps = 0 - episode_reward = 0. - episode_cost = 0. +def run_vectorized_exploration(args, env, agent, buffer, stats:Statistics, train=True, shielded=True): + # track currently running and leftover episodes + open_episodes = args.train_episodes if train else args.test_episodes + running_episodes = args.num_vectorized_envs + + # initialize mask and stats per episode + mask = np.ones(args.num_vectorized_envs, dtype=np.bool_) + episode_steps = np.zeros_like(mask, dtype=np.uint64) + episode_reward = np.zeros_like(mask, dtype=np.float64) + episode_cost = np.zeros_like(mask, dtype=np.float64) + + # adjust mask in case we have fewer runs than environments + if open_episodes < args.num_vectorized_envs: + mask[open_episodes:] = False + running_episodes = open_episodes + open_episodes -= running_episodes + + state, info = env.reset(seed=random.randint(0, 2**31-1)) - state, info = env.reset() + while running_episodes > 0: + # sample and execute actions + actions = agent.get_action(state, eval=not train).cpu().numpy() + next_state, reward, cost, terminated, truncated, info = env.step(actions) + done = terminated | truncated + not_done_masked = ((~done) & mask) + done_masked = done & mask + done_masked_count = done_masked.sum() + + # increment stats + episode_steps[mask] += 1 + episode_reward[mask] += reward[mask] + episode_cost[mask] += cost[mask] + + # if any run has finished, we need to take special care + # 1. train: extract final_observation from info dict (single envs autoreset, no manual reset needed) and add to buffer + # 2. train+test: log episode stats using Statistics class + # 3. train+test: reset episode stats + # 4. 
train+test: adjust mask (if necessary) + if done_masked_count > 0: + if train: + # add experiences to buffer + buffer.add( + state[not_done_masked], + actions[not_done_masked], + reward[not_done_masked], + cost[not_done_masked], + next_state[not_done_masked], + terminated[not_done_masked] + ) + buffer.add( + state[done_masked], + actions[done_masked], + reward[done_masked], + cost[done_masked], + np.stack(info['final_observation'], axis=0)[done_masked], + terminated[done_masked] + ) + + # record finished episodes + ticks = stats.total_train_steps + np.cumsum(episode_steps[done_masked], axis=0) + stats.log_tensorboard("train/returns", episode_reward[done_masked], ticks) + stats.log_tensorboard("train/costs", episode_cost[done_masked], ticks) + stats.log_tensorboard("train/steps", episode_steps[done_masked], ticks) + stats.log_tensorboard("train/unsafe", (episode_cost[done_masked] > args.cost_limit).astype(np.uint8), ticks) + + stats.log_train_history((episode_cost[done_masked] > args.cost_limit).astype(np.uint8)) + + stats.total_train_episodes += done_masked_count + stats.total_train_steps += episode_steps[done_masked].sum() + stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum() + + else: + # record finished episodes + # stats module performs averaging over all when logging episodes + stats.log_test_history( + steps=episode_steps[done_masked], + reward=episode_reward[done_masked], + cost=episode_cost[done_masked], + unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8) + ) + + # reset episode stats + state = next_state + episode_steps[done_masked] = 0 + episode_reward[done_masked] = 0 + episode_cost[done_masked] = 0 + + # adjust mask, running and open episodes counter + if open_episodes < done_masked_count: # fewer left than just finished + done_masked_idxs = done_masked.nonzero()[0] + mask[done_masked_idxs[open_episodes:]] = False + running_episodes -= (done_masked_count - open_episodes) + open_episodes = 0 + else: # at least as many left than just finished + open_episodes -= done_masked_count + + # no run has finished, just record experiences (if training) + else: + if train: + buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask]) + state = next_state + + # average and log finished test episodes + if not train: + stats.log_test_tensorboard(f"test/{'shielded' if shielded else 'unshielded'}", stats.total_train_steps) + +def single_exploration(args, env, agent, buffer, stats, train=True, shielded=True): + # NOTE: Unused function + state, info = env.reset(seed=random.randint(0, 2**31-1)) + episode_reward = 0 + episode_cost = 0 + episode_steps = 0 done = False while not done: - action = agent.sample(state, shielded=shielded) + with torch.no_grad(): + action = agent.get_action(np.expand_dims(state, 0), eval=not train).cpu().numpy().squeeze(0) next_state, reward, cost, terminated, truncated, info = env.step(action) done = terminated or truncated - - if train: - buffer.add(state, action, reward, cost, next_state) - state = next_state - - episode_steps += 1 episode_reward += reward episode_cost += cost + episode_steps += 1 - if train and args.enforce_cost_limit: # we dont care about the cost limit while testing - done = done or (episode_cost >= args.cost_limit) - - return episode_steps, episode_reward, episode_cost - -@torch.no_grad -def run_episode_multiple(num_episodes, train, shielded): - global total_episodes, total_failures, total_steps - avg_steps = 0 - avg_reward = 0 - avg_cost = 0 - avg_failures = 0 - - for 
_ in range(num_episodes): - episode_steps, episode_reward, episode_cost = run_episode_single(train=train, shielded=shielded) - episode_failure = int(episode_cost >= args.cost_limit) + state = np.expand_dims(state, 0) + next_state = np.expand_dims(next_state, 0) + action = np.expand_dims(action, 0) + reward = np.array([reward]) + cost = np.array([cost]) if train: - total_episodes += 1 - total_failures += episode_failure - total_steps += env.spec.max_episode_steps + buffer.add(state, action, reward, cost, next_state, np.array([terminated])) + if stats.total_train_steps >= 10_000: + x = agent.learn(experiences=buffer.sample(n=args.batch_size)) + actor_loss, alpha_loss, critic1_loss, critic2_loss, cql1_scaled_loss, cql2_scaled_loss, current_alpha, cql_alpha_loss, cql_alpha = x + stats.writer.add_scalar("debug/actor_loss", actor_loss, stats.total_updates) + stats.writer.add_scalar("debug/alpha_loss", alpha_loss, stats.total_updates) + stats.writer.add_scalar("debug/critic1_loss", critic1_loss, stats.total_updates) + stats.writer.add_scalar("debug/critic2_loss", critic2_loss, stats.total_updates) + stats.writer.add_scalar("debug/cql1_scaled_loss", cql1_scaled_loss, stats.total_updates) + stats.writer.add_scalar("debug/cql2_scaled_loss", cql2_scaled_loss, stats.total_updates) + stats.writer.add_scalar("debug/current_alpha", current_alpha, stats.total_updates) + stats.writer.add_scalar("debug/cql_alpha_loss", cql_alpha_loss, stats.total_updates) + stats.writer.add_scalar("debug/cql_alpha", cql_alpha, stats.total_updates) + stats.total_updates += 1 - writer.add_scalar("train/episode_reward", episode_reward, total_episodes) - writer.add_scalar("train/episode_cost", episode_cost, total_episodes) - writer.add_scalar("train/episode_failure", episode_failure, total_episodes) + state = next_state.squeeze(0) + stats.writer.add_scalar("train/returns", episode_reward, stats.total_train_steps) + stats.writer.add_scalar("train/costs", episode_cost, stats.total_train_steps) + stats.writer.add_scalar("train/steps", episode_steps, stats.total_train_steps) - avg_steps += episode_steps - avg_reward += episode_reward - avg_cost += episode_cost - avg_failures += episode_failure - - avg_steps /= num_episodes - avg_reward /= num_episodes - avg_cost /= num_episodes - avg_failures /= num_episodes - - return avg_steps, avg_reward, avg_cost, avg_failures + stats.total_train_episodes += 1 + stats.total_train_steps += episode_steps + stats.total_train_unsafe += int(episode_cost > args.cost_limit) ################## # MAIN LOOP ################## -finished = False -while not finished: - for _ in range(args.train_until_test): - if total_steps >= args.total_steps: - finished = True - break - avg_steps, avg_reward, avg_cost, avg_failures = run_episode_multiple(num_episodes=args.train_episodes, train=True, shielded=True) - for i in range(args.update_iterations): - agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=total_episodes) - if args.clear_buffer: - buffer.clear() - - for shielded, postfix in zip([True, False], ["shielded", "unshielded"]): - avg_steps, avg_reward, avg_cost, avg_failures = run_episode_multiple(num_episodes=args.test_episodes, train=False, shielded=shielded) - writer.add_scalar(f"test/avg_reward_{postfix}", avg_reward, total_episodes) - writer.add_scalar(f"test/avg_cost_{postfix}", avg_cost, total_episodes) - writer.add_scalar(f"test/avg_failures_{postfix}", avg_failures, total_episodes) - -writer.flush() \ No newline at end of file +def main(args, env, agent, buffer, stats): + finished = 
False + while not finished: + # Training + Update Loop + for _ in range(args.train_until_test): + + # 1. Run exploration for training + run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True) + + # 2. Perform updates + for _ in range(args.update_iterations): + agent.learn(experiences=buffer.sample(n=args.batch_size)) + + # 3. After update stuff + if args.clear_buffer: + buffer.clear() + + # Test loop (shielded and unshielded) + # for shielded in [True, False]: + run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=True) + +if __name__ == '__main__': + args = cmd_args() + main(args, *setup(args)) \ No newline at end of file diff --git a/src/buffer.py b/src/buffer.py index b74190d5e20ee224520cc79f373c171e133a71d4..6b72adbb5bc8264b4db5d9144b4f9bfabb3e686d 100644 --- a/src/buffer.py +++ b/src/buffer.py @@ -1,33 +1,68 @@ -import numpy -import gymnasium +import numpy as np class ReplayBuffer(): - def __init__(self, env:gymnasium.Env, cap): - self._cap = cap - self._size = 0 - self._ptr = 0 + """ + Buffer for storing experiences. Supports sampling and adding experiences and clearing the buffer. Handles batched experiences. + """ + def __init__(self, env, cap): + self._cap = max(1,cap) + self._size = 0 # number of experiences in the buffer + self._ptr = 0 # pointer to the next available slot in the buffer + + self._states = np.zeros((cap, env.observation_space.shape[-1]), dtype=np.float64) + self._actions = np.zeros((cap, env.action_space.shape[-1]), dtype=np.float64) + self._rewards = np.zeros((cap, ), dtype=np.float64) + self._costs = np.zeros((cap, ), dtype=np.float64) + self._next_states = np.zeros_like(self._states) + self._dones = np.zeros((cap, ), dtype=np.uint8) + + + def _add(self, state, action, reward, cost, next_state, done, start, end): + self._states[start:end] = state + self._actions[start:end] = action + self._rewards[start:end] = reward + self._costs[start:end] = cost + self._next_states[start:end] = next_state + self._dones[start:end] = done + + + def add(self, state, action, reward, cost, next_state, done): + """ + Adds experiences to the buffer. Assumes batched experiences. 
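+ If a batch does not fit into the remaining capacity, the write wraps around to the start of the buffer and overwrites existing entries.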
+ """ + n = state.shape[0] # NOTE: n should be less than or equal to the buffer capacity + idx_start = self._ptr + idx_end = self._ptr + n + + # if the buffer has capacity, add the experiences to the end of the buffer + if idx_end <= self._cap: + self._add(state, action, reward, cost, next_state, done, idx_start, idx_end) + + # if the buffer does not have capacity, add the experiences to the end of the buffer and wrap around + else: + k = self._cap - idx_start + idx_end = n - k + self._add(state[:k], action[:k], reward[:k], cost[:k], next_state[:k], done[:k], start=idx_start, end=self._cap) + self._add(state[k:], action[k:], reward[k:], cost[k:], next_state[k:], done[k:], start=0, end=idx_end) + + # update the buffer size and pointer + self._ptr = idx_end + if self._size < self._cap: + self._size = min(self._cap, self._size + n) - self._states = numpy.zeros((cap, env.observation_space.shape[0]), dtype=numpy.float64) - self._actions = numpy.zeros((cap, env.action_space.shape[0]), dtype=numpy.float64) - self._rewards = numpy.zeros((cap, ), dtype=numpy.float64) - self._costs = numpy.zeros((cap, ), dtype=numpy.float64) - self._next_states = numpy.zeros_like(self._states) - - def add(self, state, action, reward, cost, next_state): - self._states[self._ptr] = state - self._actions[self._ptr] = action - self._rewards[self._ptr] = reward - self._costs[self._ptr] = cost - self._next_states[self._ptr] = next_state - self._ptr = (self._ptr+1) % self._cap - self._size = min(self._size+1, self._cap) - def sample(self, n): - idxs = numpy.random.randint(low=0, high=self._size, size=n) + """ + Samples n experiences from the buffer. + """ + idxs = np.random.randint(low=0, high=self._size, size=n) return self._states[idxs], self._actions[idxs], self._rewards[idxs], \ - self._costs[idxs], self._next_states[idxs] + self._costs[idxs], self._next_states[idxs], self._dones[idxs] + def clear(self): + """ + Clears the buffer. + """ self._size = 0 self._ptr = 0 \ No newline at end of file diff --git a/src/cql_sac/agent.py b/src/cql_sac/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6e4cb0cef61d21173db14de1c7c44b4b89f447 --- /dev/null +++ b/src/cql_sac/agent.py @@ -0,0 +1,363 @@ +import torch +import torch.optim as optim +import torch.nn.functional as F +import torch.nn as nn +from torch.nn.utils import clip_grad_norm_ +from .networks import Critic, Actor +import numpy as np +import math +import copy + + +class CSCCQLSAC(nn.Module): + """Interacts with and learns from the environment.""" + + def __init__(self, + env, + args, + stats + ): + """Initialize an Agent object. 
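+ Builds the actor, twin reward critics and twin safety critics (each with a target copy) together with the entropy, CQL and CSC lambda parameters and their optimizers.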
+ + Params + ====== + env : the vector environment + args : the argparse arguments + """ + super(CSCCQLSAC, self).__init__() + self.stats = stats + + state_size = env.observation_space.shape[-1] + action_size = env.action_space.shape[-1] + hidden_size = args.hidden_size + self.action_size = action_size + + self.device = args.device + + self.learning_rate = args.learning_rate + self.gamma = args.gamma + self.tau = args.tau + self.clip_grad_param = 1 + + self.target_entropy = -action_size # -dim(A) + + self.log_alpha = torch.tensor([0.0], requires_grad=True, device=self.device) + self.alpha = self.log_alpha.exp().detach() + self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=self.learning_rate) + + # CQL params + self.cql_with_lagrange = args.cql_with_lagrange + self.cql_temp = args.cql_temp + self.cql_weight = args.cql_weight + self.cql_target_action_gap = args.cql_target_action_gap + self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.device) + self.cql_alpha_optimizer = optim.Adam(params=[self.cql_log_alpha], lr=self.learning_rate) + + # CSC params + self.csc_shield_iterations = 100 + self.csc_alpha = args.csc_alpha + self.csc_beta = args.csc_beta + self.csc_delta = args.csc_delta + self.csc_chi = args.csc_chi + self.csc_avg_unsafe = args.csc_chi + + self.csc_lambda = torch.tensor([args.csc_lambda], requires_grad=True, device=self.device) + self.csc_lambda_optimizer = optim.Adam(params=[self.csc_lambda], lr=self.learning_rate) + + # Actor Network + self.actor_local = Actor(state_size, action_size, hidden_size).to(self.device) + self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.learning_rate) + + # Safety Critic Network (w/ Target Network) + self.safety_critic1 = Critic(state_size, action_size, hidden_size).to(self.device) + self.safety_critic2 = Critic(state_size, action_size, hidden_size).to(self.device) + + self.safety_critic1_target = Critic(state_size, action_size, hidden_size).to(self.device) + self.safety_critic1_target.load_state_dict(self.safety_critic1.state_dict()) + + self.safety_critic2_target = Critic(state_size, action_size, hidden_size).to(self.device) + self.safety_critic2_target.load_state_dict(self.safety_critic2.state_dict()) + + self.safety_critic1_optimizer = optim.Adam(self.safety_critic1.parameters(), lr=self.learning_rate) + self.safety_critic2_optimizer = optim.Adam(self.safety_critic2.parameters(), lr=self.learning_rate) + + # Critic Network (w/ Target Network) + self.critic1 = Critic(state_size, action_size, hidden_size).to(self.device) + self.critic2 = Critic(state_size, action_size, hidden_size).to(self.device) + + self.critic1_target = Critic(state_size, action_size, hidden_size).to(self.device) + self.critic1_target.load_state_dict(self.critic1.state_dict()) + + self.critic2_target = Critic(state_size, action_size, hidden_size).to(self.device) + self.critic2_target.load_state_dict(self.critic2.state_dict()) + + self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate) + self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate) + + + def get_action(self, state, eval=False): + """ + Returns shielded actions for given state as per current policy. + + Note: eval is currently ignored. 
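+ Candidate actions are resampled from the actor for up to csc_shield_iterations rounds; for each state the least-unsafe candidate is kept, and sampling stops early once all candidates fall below the estimated cost threshold.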
+ """ + state = torch.from_numpy(state).float().to(self.device) + + batch_size = state.shape[0] + unsafety_threshold = (1 - self.gamma) * (self.csc_chi - self.csc_avg_unsafe) + unsafety_best = torch.full((batch_size, ), fill_value=unsafety_threshold+1).to(self.device) + action_best = torch.zeros(batch_size, self.action_size).to(self.device) + + # Run at max 'csc_shield_iterations' iterations to find safe action + for _ in range(self.csc_shield_iterations): + # If all actions are already safe, break + mask_safe = unsafety_best <= unsafety_threshold + if mask_safe.all(): break + + # Sample new actions + with torch.no_grad(): + action = self.actor_local.get_action(state).to(self.device) + + # Estimate safety of new actions + q1 = self.safety_critic1(state, action) + q2 = self.safety_critic2(state, action) + unsafety = torch.min(q1, q2).squeeze(1) + + # Update best actions if they are still unsafe and new actions are safer + mask_update = (~mask_safe) & (unsafety < unsafety_best) + unsafety_best[mask_update] = unsafety[mask_update] + action_best[mask_update] = action[mask_update] + + return action_best + + def calc_policy_loss(self, states, alpha): + actions_pred, log_pis = self.actor_local.evaluate(states) + + q1 = self.critic1(states, actions_pred.squeeze(0)) + q2 = self.critic2(states, actions_pred.squeeze(0)) + min_Q = torch.min(q1,q2) + actor_loss = ((alpha * log_pis - min_Q )).mean() + return actor_loss, log_pis + + def _compute_policy_values(self, obs_pi, obs_q): + #with torch.no_grad(): + actions_pred, log_pis = self.actor_local.evaluate(obs_pi) + qs1 = self.safety_critic1(obs_q, actions_pred) + qs2 = self.safety_critic2(obs_q, actions_pred) + return qs1 - log_pis.detach(), qs2 - log_pis.detach() + + def _compute_random_values(self, obs, actions, critic): + random_values = critic(obs, actions) + random_log_probs = math.log(0.5 ** self.action_size) + return random_values - random_log_probs + + def learn(self, experiences): + """Updates actor, critics and entropy_alpha parameters using given batch of experience tuples. 
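+ The safety critics are additionally trained with a conservative (CQL) penalty on the cost estimates, and the Lagrange multiplier csc_lambda is updated to enforce the chance constraint on unsafe episodes.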
+ Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state)) + Critic_loss = MSE(Q, Q_target) + Actor_loss = α * log_pi(a|s) - Q(s,a) + where: + actor_target(state) -> action + critic_target(state, action) -> Q-value + Params + ====== + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, c, s', done) tuples + gamma (float): discount factor + """ + self.csc_avg_unsafe = self.stats.train_unsafe_avg + states, actions, rewards, costs, next_states, dones = experiences + + states = torch.from_numpy(states).float().to(self.device) + actions = torch.from_numpy(actions).float().to(self.device) + rewards = torch.from_numpy(rewards).float().to(self.device).view(-1, 1) + costs = torch.from_numpy(costs).float().to(self.device).view(-1, 1) + next_states = torch.from_numpy(next_states).float().to(self.device) + dones = torch.from_numpy(dones).float().to(self.device).view(-1, 1) + + # ---------------------------- update critic ---------------------------- # + # Get predicted next-state actions and Q values from target models + with torch.no_grad(): + next_action, new_log_pi = self.actor_local.evaluate(next_states) + Q_target1_next = self.critic1_target(next_states, next_action) + Q_target2_next = self.critic2_target(next_states, next_action) + Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha * new_log_pi + # Compute Q targets for current states (y_i) + Q_targets = rewards + (self.gamma * (1 - dones) * Q_target_next) + + # Compute critic loss + q1 = self.critic1(states, actions) + q2 = self.critic2(states, actions) + + critic1_loss = F.mse_loss(q1, Q_targets) + critic2_loss = F.mse_loss(q2, Q_targets) + + # Update critics + # critic 1 + self.critic1_optimizer.zero_grad() + critic1_loss.backward(retain_graph=True) + clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param) + self.critic1_optimizer.step() + # critic 2 + self.critic2_optimizer.zero_grad() + critic2_loss.backward() + clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param) + self.critic2_optimizer.step() + + # ---------------------------- update safety critic ---------------------------- # + # Get predicted next-state actions and Q values from target models + with torch.no_grad(): + next_action, new_log_pi = self.actor_local.evaluate(next_states) + Q_target1_next = self.safety_critic1_target(next_states, next_action) + Q_target2_next = self.safety_critic2_target(next_states, next_action) + Q_target_next = torch.min(Q_target1_next, Q_target2_next) # - self.alpha * new_log_pi + # Compute Q targets for current states (y_i) + Q_targets = costs + (self.gamma * (1 - dones) * Q_target_next) + + # Compute safety_critic loss + q1 = self.safety_critic1(states, actions) + q2 = self.safety_critic2(states, actions) + + safety_critic1_loss = F.mse_loss(q1, Q_targets) + safety_critic2_loss = F.mse_loss(q2, Q_targets) + + # CQL addon + num_repeat = 10 + random_actions = torch.FloatTensor(q1.shape[0] * num_repeat, actions.shape[-1]).uniform_(-1, 1).to(self.device) + temp_states = states.unsqueeze(1).repeat(1, num_repeat, 1).view(states.shape[0] * num_repeat, states.shape[1]) + temp_next_states = next_states.unsqueeze(1).repeat(1, num_repeat, 1).view(next_states.shape[0] * num_repeat, next_states.shape[1]) + + current_pi_values1, current_pi_values2 = self._compute_policy_values(temp_states, temp_states) + next_pi_values1, next_pi_values2 = self._compute_policy_values(temp_next_states, temp_states) + + random_values1 = self._compute_random_values(temp_states, 
random_actions, self.safety_critic1).reshape(states.shape[0], num_repeat, 1) + random_values2 = self._compute_random_values(temp_states, random_actions, self.safety_critic2).reshape(states.shape[0], num_repeat, 1) + + current_pi_values1 = current_pi_values1.reshape(states.shape[0], num_repeat, 1) + current_pi_values2 = current_pi_values2.reshape(states.shape[0], num_repeat, 1) + + next_pi_values1 = next_pi_values1.reshape(states.shape[0], num_repeat, 1) + next_pi_values2 = next_pi_values2.reshape(states.shape[0], num_repeat, 1) + + cat_q1 = torch.cat([random_values1, current_pi_values1, next_pi_values1], 1) + cat_q2 = torch.cat([random_values2, current_pi_values2, next_pi_values2], 1) + + assert cat_q1.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q1 instead has shape: {cat_q1.shape}" + assert cat_q2.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q2 instead has shape: {cat_q2.shape}" + + # flipped sign of cql1_scaled_loss and cql2_scaled_loss + cql1_scaled_loss = -(torch.logsumexp(cat_q1 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q1.mean() * self.cql_weight) + cql2_scaled_loss = -(torch.logsumexp(cat_q2 / self.cql_temp, dim=1).mean() * self.cql_weight * self.cql_temp) + (q2.mean() * self.cql_weight) + + cql_alpha_loss = torch.FloatTensor([0.0]) + cql_alpha = torch.FloatTensor([1.0]) + if self.cql_with_lagrange: + cql_alpha = torch.clamp(self.cql_log_alpha.exp(), min=0.0, max=1000000.0) + cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.cql_target_action_gap) + cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.cql_target_action_gap) + + self.cql_alpha_optimizer.zero_grad() + cql_alpha_loss = (- cql1_scaled_loss - cql2_scaled_loss) * 0.5 + cql_alpha_loss.backward(retain_graph=True) + self.cql_alpha_optimizer.step() + + total_c1_loss = safety_critic1_loss + cql1_scaled_loss + total_c2_loss = safety_critic2_loss + cql2_scaled_loss + + + # Update safety_critics + # safety_critic 1 + self.safety_critic1_optimizer.zero_grad() + total_c1_loss.backward(retain_graph=True) + clip_grad_norm_(self.safety_critic1.parameters(), self.clip_grad_param) + self.safety_critic1_optimizer.step() + # safety_critic 2 + self.safety_critic2_optimizer.zero_grad() + total_c2_loss.backward() + clip_grad_norm_(self.safety_critic2.parameters(), self.clip_grad_param) + self.safety_critic2_optimizer.step() + + # ---------------------------- update csc lambda ---------------------------- # + # Estimate cost advantage + with torch.no_grad(): + q1 = self.safety_critic1(states, actions) + q2 = self.safety_critic2(states, actions) + v = torch.min(q1, q2) + + new_action, new_log_pi = self.actor_local.evaluate(states) + q1 = self.safety_critic1(states, new_action) + q2 = self.safety_critic2(states, new_action) + q = torch.min(q1, q2) + + cost_advantage = (q - v).mean() + + # Compute csc lambda loss + csc_lambda_loss = -self.csc_lambda*(self.csc_avg_unsafe + (1 / (1 - self.gamma)) * cost_advantage - self.csc_chi) + + self.csc_lambda_optimizer.zero_grad() + csc_lambda_loss.backward() + self.csc_lambda_optimizer.step() + + # ---------------------------- update actor ---------------------------- # + # Estimate reward advantage + q1 = self.critic1(states, actions) + q2 = self.critic2(states, actions) + v = torch.min(q1, q2).detach() + + new_action, new_log_pi = self.actor_local.evaluate(states) + q1 = self.critic1(states, new_action) + q2 = self.critic2(states, new_action) + q = torch.min(q1, q2) + + reward_advantage = q - v + + # Optimize actor + actor_loss = ((self.alpha * 
new_log_pi - reward_advantage)).mean() + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + # Compute alpha loss + alpha_loss = - (self.log_alpha.exp() * (new_log_pi + self.target_entropy).detach()).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + self.alpha = self.log_alpha.exp().detach() + + # ----------------------- update target networks ----------------------- # + self.soft_update(self.critic1, self.critic1_target) + self.soft_update(self.critic2, self.critic2_target) + self.soft_update(self.safety_critic1, self.safety_critic1_target) + self.soft_update(self.safety_critic2, self.safety_critic2_target) + + # ----------------------- update stats ----------------------- # + data = { + "actor_loss": actor_loss.item(), + "alpha_loss": alpha_loss.item(), + "alpha": self.alpha.item(), + "lambda_loss": csc_lambda_loss.item(), + "lambda": self.csc_lambda.item(), + "critic1_loss": critic1_loss.item(), + "critic2_loss": critic2_loss.item(), + "cql1_scaled_loss": cql1_scaled_loss.item(), + "cql2_scaled_loss": cql2_scaled_loss.item(), + "total_c1_loss": total_c1_loss.item(), + "total_c2_loss": total_c2_loss.item(), + "cql_alpha_loss": cql_alpha_loss.item(), + "cql_alpha": cql_alpha.item() + } + if self.stats.total_updates % 8 == 0: + self.stats.log_update_tensorboard(data) + self.stats.total_updates += 1 + return data + + def soft_update(self, local_model , target_model): + """Soft update model parameters. + θ_target = τ*θ_local + (1 - τ)*θ_target + Params + ====== + local_model: PyTorch model (weights will be copied from) + target_model: PyTorch model (weights will be copied to) + tau (float): interpolation parameter + """ + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data) diff --git a/src/cql_sac/buffer.py b/src/cql_sac/buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf3b751618b42d57e1e38a84762f06138d16eca --- /dev/null +++ b/src/cql_sac/buffer.py @@ -0,0 +1,41 @@ +import numpy as np +import random +import torch +from collections import deque, namedtuple + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, buffer_size, batch_size, device): + """Initialize a ReplayBuffer object. 
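+ Experiences are stored in a deque of length buffer_size and sampled uniformly at random.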
+ Params + ====== + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + seed (int): random seed + """ + self.device = device + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = self.experience(state, action, reward, next_state, done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(np.stack([e.state for e in experiences if e is not None])).float().to(self.device) + actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(self.device) + rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device) + next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(self.device) + dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(self.device) + + return (states, actions, rewards, next_states, dones) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) diff --git a/src/cql_sac/networks.py b/src/cql_sac/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..72f26a3dc986994f68c3bc7a241e475b24da7e7c --- /dev/null +++ b/src/cql_sac/networks.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +from torch.distributions import Normal +import numpy as np +import torch.nn.functional as F + + +def hidden_init(layer): + fan_in = layer.weight.data.size()[0] + lim = 1. / np.sqrt(fan_in) + return (-lim, lim) + +class Actor(nn.Module): + """Actor (Policy) Model.""" + + def __init__(self, state_size, action_size, hidden_size=32, log_std_min=-20, log_std_max=2): + """Initialize parameters and build model. + Params + ====== + state_size (int): Dimension of each state + action_size (int): Dimension of each action + hidden_size (int): Number of nodes in each hidden layer + """ + super(Actor, self).__init__() + self.log_std_min = log_std_min + self.log_std_max = log_std_max + + self.fc1 = nn.Linear(state_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + + self.mu = nn.Linear(hidden_size, action_size) + self.log_std_linear = nn.Linear(hidden_size, action_size) + + def forward(self, state): + x = F.relu(self.fc1(state)) + x = F.relu(self.fc2(x)) + mu = self.mu(x) + + log_std = self.log_std_linear(x) + log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + return mu, log_std + + def evaluate(self, state, epsilon=1e-6): + mu, log_std = self.forward(state) + std = log_std.exp() + dist = Normal(mu, std) + e = dist.rsample().to(state.device) + action = torch.tanh(e) + log_prob = (dist.log_prob(e) - torch.log(1 - action.pow(2) + epsilon)).sum(1, keepdim=True) + + return action, log_prob + + + def get_action(self, state): + """ + returns the action based on a squashed gaussian policy. 
That means the samples are obtained according to: + a(s,e)= tanh(mu(s)+sigma(s)+e) + """ + mu, log_std = self.forward(state) + std = log_std.exp() + dist = Normal(mu, std) + e = dist.rsample().to(state.device) + action = torch.tanh(e) + return action.detach().cpu() + + def get_det_action(self, state): + mu, log_std = self.forward(state) + return torch.tanh(mu).detach().cpu() + + +class Critic(nn.Module): + """Critic (Value) Model.""" + + def __init__(self, state_size, action_size, hidden_size=32): + """Initialize parameters and build model. + Params + ====== + state_size (int): Dimension of each state + action_size (int): Dimension of each action + hidden_size (int): Number of nodes in the network layers + """ + super(Critic, self).__init__() + self.fc1 = nn.Linear(state_size+action_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.fc3 = nn.Linear(hidden_size, 1) + self.reset_parameters() + + def reset_parameters(self): + self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) + self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) + self.fc3.weight.data.uniform_(-3e-3, 3e-3) + + def forward(self, state, action): + """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" + x = torch.cat((state, action), dim=-1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) \ No newline at end of file diff --git a/src/cql_sac/utils.py b/src/cql_sac/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3a87ce2353474c35e3b42b00934f514400d0bf --- /dev/null +++ b/src/cql_sac/utils.py @@ -0,0 +1,43 @@ +import torch +import numpy as np + +def save(args, save_name, model, wandb, ep=None): + import os + save_dir = './trained_models/' + if not os.path.exists(save_dir): + os.makedirs(save_dir) + if not ep == None: + torch.save(model.state_dict(), save_dir + args.run_name + save_name + str(ep) + ".pth") + wandb.save(save_dir + args.run_name + save_name + str(ep) + ".pth") + else: + torch.save(model.state_dict(), save_dir + args.run_name + save_name + ".pth") + wandb.save(save_dir + args.run_name + save_name + ".pth") + +def collect_random(env, dataset, num_samples=200): + state = env.reset() + for _ in range(num_samples): + action = env.action_space.sample() + next_state, reward, done, _ = env.step(action) + dataset.add(state, action, reward, next_state, done) + state = next_state + if done: + state = env.reset() + +def evaluate(env, policy, eval_runs=5): + """ + Makes an evaluation run with the current policy + """ + reward_batch = [] + for i in range(eval_runs): + state = env.reset() + + rewards = 0 + while True: + action = policy.get_action(state, eval=True) + + state, reward, done, _ = env.step(action) + rewards += reward + if done: + break + reward_batch.append(rewards) + return np.mean(reward_batch) \ No newline at end of file diff --git a/src/environment.py b/src/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb89717d076740298985b5eccfa8703497562c7 --- /dev/null +++ b/src/environment.py @@ -0,0 +1,29 @@ +import safety_gymnasium +import gymnasium +import numpy as np + +class Gymnasium2SafetyGymnasium(gymnasium.Wrapper): + def step(self, action): + state, reward, terminated, truncated, info = super().step(action) + if 'cost' in info: + cost = info['cost'] + elif isinstance(reward, (int, float)): + cost = 0 + elif isinstance(reward, np.ndarray): + cost = np.zeros_like(reward) + elif isinstance(reward, list): + cost = [0]*len(reward) + else: + raise NotImplementedError("reward 
type not recognized") # for now + return state, reward, cost, terminated, truncated, info + +def create_environment(args): + if args.env_id.startswith("Safety"): + env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False) + if args.env_id.startswith("Gymnasium_"): + id = args.env_id[len('Gymnasium_'):] + env = gymnasium.make_vec(id, num_envs=args.num_vectorized_envs, vectorization_mode="sync") + env = Gymnasium2SafetyGymnasium(env) + if args.env_id.startswith("RaceTrack"): + raise NotImplementedError("RaceTrack environment is not implemented yet.") + return env \ No newline at end of file diff --git a/src/networks.py b/src/networks.py index 85f6bf9c876b530fc43fcfd59125a67faa166523..5d73c344d59e2718f98ddf48e8d5b09815dacc41 100644 --- a/src/networks.py +++ b/src/networks.py @@ -3,10 +3,6 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributions import Normal -""" -Copied from spice project. -""" - LOG_SIG_MAX = 2 LOG_SIG_MIN = -20 epsilon = 1e-6 @@ -36,24 +32,23 @@ class ValueNetwork(nn.Module): class QNetwork(nn.Module): - def __init__(self, num_inputs, num_actions, hidden_dim): + def __init__(self, num_inputs, num_actions, hidden_dim, sigmoid_activation=False): super().__init__() - # Q1 architecture self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) self.linear2 = nn.Linear(hidden_dim, hidden_dim) self.linear3 = nn.Linear(hidden_dim, 1) + self.last_activation = F.sigmoid if sigmoid_activation else nn.Identity() self.apply(weights_init_) def forward(self, state, action): - xu = torch.cat([state, action], 1) - - x1 = F.relu(self.linear1(xu)) - x1 = F.relu(self.linear2(x1)) - x1 = self.linear3(x1) - - return x1#, x2 + x = torch.cat([state, action], -1) + x = F.relu(self.linear1(x)) + x = F.relu(self.linear2(x)) + x = self.linear3(x) + x = self.last_activation(x) + return x class GaussianPolicy(nn.Module): @@ -74,12 +69,9 @@ class GaussianPolicy(nn.Module): self.action_bias = torch.tensor(0.) else: self.action_scale = torch.FloatTensor( - (action_space.high - action_space.low) / 2.) + (action_space.high[0] - action_space.low[0]) / 2.) self.action_bias = torch.FloatTensor( - (action_space.high + action_space.low) / 2.) - - def __call__(self, state, action): - return self.log_prob(state, action) + (action_space.high[0] + action_space.low[0]) / 2.) 
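+ # NOTE: using high[0]/low[0] assumes every action dimension shares the same bounds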
def forward(self, state): x = F.relu(self.linear1(state)) @@ -90,28 +82,42 @@ class GaussianPolicy(nn.Module): return mean, log_std def distribution(self, state): - mean, log_std = self.forward(state) - std = log_std.exp() - normal = Normal(mean, std, validate_args=False) + try: + mean, log_std = self.forward(state) + std = log_std.exp() + normal = Normal(mean, std) + except ValueError: + print("state:", state) + print(mean) + print(log_std) + print(std) + + print("linear1:", self.linear1) + print(self.linear2) + print(self.mean_linear) + print(self.log_std_linear) + + print("bias:", self.action_bias) + print(self.action_scale) + exit(0) + return normal def log_prob(self, state, action): dist = self.distribution(state) - return dist.log_prob(action) + y_t = (action - self.action_bias) / self.action_scale + x_t = torch.atanh(y_t) + log_prob = dist.log_prob(x_t) + log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) + log_prob = log_prob.sum(axis=-1, keepdim=True) + return log_prob - def sample(self, state): - mean, log_std = self.forward(state) - std = log_std.exp() - normal = Normal(mean, std) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + def sample(self, state, num_samples=1): + normal = self.distribution(state) + x_t = normal.rsample((num_samples, )) # for reparameterization trick (mean + std * N(0,1)) y_t = torch.tanh(x_t) action = y_t * self.action_scale + self.action_bias - log_prob = normal.log_prob(x_t) - # Enforcing Action Bound - log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) - log_prob = log_prob.sum(1, keepdim=True) - mean = torch.tanh(mean) * self.action_scale + self.action_bias - return action, log_prob, mean + return action def to(self, device): self.action_scale = self.action_scale.to(device) diff --git a/src/policy.py b/src/policy.py index d15383db54e784be8a2ca5a981c1b31ace07f262..ae73543e1805f2c0bc48ad6f9196468c0227134b 100644 --- a/src/policy.py +++ b/src/policy.py @@ -1,11 +1,13 @@ import torch from torch.distributions import kl_divergence -from torch.func import functional_call, vmap, grad -from .networks import GaussianPolicy, QNetwork, ValueNetwork +from torch.nn.functional import mse_loss +from src.networks import GaussianPolicy, QNetwork, ValueNetwork +from src.helper import soft_update, hard_update, apply_gradient class CSCAgent(): - def __init__(self, env, args, writer) -> None: - self._writer = writer + def __init__(self, env, args, buffer, stats) -> None: + self._buffer = buffer + self._stats = stats self._device = args.device self._shield_iterations = args.shield_iterations @@ -16,58 +18,44 @@ class CSCAgent(): self._gamma = args.csc_gamma self._delta = args.csc_delta self._chi = args.csc_chi - self._avg_failures = self._chi + self._avg_train_unsafe = self._chi self._batch_size = args.batch_size - self._lambda = args.csc_lambda self._expectation_estimation_samples = args.expectation_estimation_samples self._tau = args.tau - - num_inputs = env.observation_space.shape[0] - num_actions = env.action_space.shape[0] - hidden_dim = args.hidden_dim - self._policy = GaussianPolicy(num_inputs, num_actions, hidden_dim, env.action_space).to(self._device) - self._safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device) - self._value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device) - - self._target_safety_critic = QNetwork(num_inputs, num_actions, hidden_dim).to(self._device) - self._target_value_network = ValueNetwork(num_inputs, hidden_dim).to(self._device) - 
self.soft_update(self._target_safety_critic, self._safety_critic, tau=1) - self.soft_update(self._target_value_network, self._value_network, tau=1) + self._hidden_dim = args.hidden_dim + self._lambda_lr = args.csc_lambda_lr + self._shielded_action_sampling = args.shielded_action_sampling + + num_inputs = env.observation_space.shape[-1] + num_actions = env.action_space.shape[-1] + self._policy = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device) + self._safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device) + self._value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device) + self._lambda = torch.nn.Parameter(torch.tensor(args.csc_lambda, requires_grad=True, device=self._device)) + + self._policy_old = GaussianPolicy(num_inputs, num_actions, self._hidden_dim, env.action_space).to(self._device) + self._target_safety_critic = QNetwork(num_inputs, num_actions, self._hidden_dim, sigmoid_activation=args.sigmoid_activation).to(self._device) + self._target_value_network = ValueNetwork(num_inputs, self._hidden_dim).to(self._device) + hard_update(self._policy_old, self._policy) + hard_update(self._target_safety_critic, self._safety_critic) + hard_update(self._target_value_network, self._value_network) self._optim_safety_critic = torch.optim.Adam(self._safety_critic.parameters(), lr=args.csc_safety_critic_lr) self._optim_value_network = torch.optim.Adam(self._value_network.parameters(), lr=args.csc_value_network_lr) - self._policy.eval() - self._safety_critic.eval() - - @staticmethod - @torch.no_grad - def soft_update(target, source, tau): - for tparam, sparam in zip(target.parameters(), source.parameters()): - tparam.copy_((1 - tau) * tparam.data + tau * sparam.data) - - @torch.no_grad - def _cost_advantage(self, states, actions): - cost_actions = self._safety_critic.forward(states, actions) - cost_states = torch.zeros_like(cost_actions) - for _ in range(self._expectation_estimation_samples): - a = self._policy.sample(states)[0].detach() - cost_states += self._target_safety_critic.forward(states, a) - cost_states /= self._expectation_estimation_samples - return cost_actions - cost_states - - - def _reward_diff(self, states, rewards, next_states): + # TODO + self._optim_policy = torch.optim.Adam(self._policy.parameters(), lr=args.csc_safety_critic_lr) + + def _update_value_network(self, states, rewards, next_states, dones): + """ + Updates the reward value network using MSE on a batch of experiences. 
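+ Bootstrap targets come from the target value network and are masked with (1 - dones) so that terminal transitions do not bootstrap.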
+ """ value_states = self._value_network.forward(states) with torch.no_grad(): value_next_states = self._target_value_network.forward(next_states) + value_next_states *= (1 - dones.view((-1, 1))) value_target = rewards.view((-1, 1)) + self._gamma * value_next_states - return value_target - value_states - - - def _update_value_network(self, states, rewards, next_states): - value_diff = self._reward_diff(states, rewards, next_states) - loss = 1/2 * torch.square(value_diff).mean() + loss = mse_loss(value_states, value_target, reduction='mean') self._optim_value_network.zero_grad() loss.backward() @@ -75,160 +63,238 @@ class CSCAgent(): return loss.item() - def _update_policy(self, states, actions, rewards, next_states): - @torch.no_grad - def estimate_ahat(states, actions, rewards, next_states): - lambda_prime = self._lambda / (1 - self._gamma) - adv_cost = self._cost_advantage(states, actions) - reward_diff = self._reward_diff(states, rewards, next_states) - return reward_diff - lambda_prime * adv_cost - - @torch.no_grad - def apply_gradient(gradient): - l = 0 - for name, param in self._policy.named_parameters(): - if not param.requires_grad: continue - n = param.numel() - param.copy_(param.data + gradient[l:l+n].reshape(param.shape)) - l += n - - def log_prob_grad(states, actions): - # https://pytorch.org/tutorials/intermediate/per_sample_grads.html - params = {k: v.detach() for k, v in self._policy.named_parameters()} - buffers = {k: v.detach() for k, v in self._policy.named_buffers()} - def compute_log_prob(params, buffers, s, a): - s = s.unsqueeze(0) - a = a.unsqueeze(0) - lp = functional_call(self._policy, (params, buffers), (s,a)) - return lp.sum() - - ft_compute_grad = grad(compute_log_prob) - ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0)) - ft_per_sample_grads = ft_compute_sample_grad(params, buffers, states, actions) - - grads = [] - for key, param in self._policy.named_parameters(): - if not param.requires_grad: continue - g = ft_per_sample_grads[key].flatten(start_dim=1) - grads.append(g) - return torch.hstack(grads) - - @torch.no_grad - def fisher_matrix(glog_prob): - b, n = glog_prob.shape - s,i = 32,0 # batch size, index - fisher = torch.zeros((n,n)) - while i < b: - g = glog_prob[i:i+s, ...] - fisher += (torch.func.vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b - i += s - return fisher - - glog_prob = log_prob_grad(states, actions).detach() - fisher = fisher_matrix(glog_prob) - ahat = estimate_ahat(states, actions, rewards, next_states) - gJ = (glog_prob * ahat).mean(dim=0).view((-1, 1)) - - beta_term = torch.sqrt((2 * self._delta) / (gJ.T @ fisher @ gJ)) - try: - gradient_term = torch.linalg.solve(fisher, gJ) - except RuntimeError: - # https://pytorch.org/docs/stable/generated/torch.linalg.pinv.html - # NOTE: requires full rank on cuda - # gradient_term = torch.linalg.lstsq(fisher.cpu(), gJ.cpu()).solution.to(self._device) # very slow - gradient_term = torch.linalg.pinv(fisher) @ gJ - - del glog_prob, ahat, fisher, gJ + def _update_safety_critic(self, states, actions, costs, next_states, dones): + """ + Updates the safety critic network using the CQL objective (Equation 2) from the CSC paper. 
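+ The loss combines a term that pushes the critic to assign higher cost to actions from the current policy than to the buffer actions (over-approximation) with a Bellman MSE term on the cost targets.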
+ """ + # states, action from old policy (from buffer) + safety_sa_env = self._safety_critic.forward(states, actions) + # states from old policy (from buffer), actions from current policy with torch.no_grad(): - dist_old = self._policy.distribution(states) - beta_j = self._beta - for j in range(self._line_search_iterations): - beta_j = beta_j * (1 - beta_j)**j - gradient = beta_j * beta_term * gradient_term - - apply_gradient(gradient) - dist_new = self._policy.distribution(states) - kl_div = kl_divergence(dist_old, dist_new).mean() - if kl_div <= self._delta: - return j - else: - # NOTE: probably better to save the weights and restore them - apply_gradient(-gradient) - return torch.nan - - - def _update_safety_critic(self, states, actions, costs, next_states): - safety_sa_env = self._safety_critic.forward(states, actions) + actions_current_policy = self.sample(states, shielded=self._shielded_action_sampling) + safety_s_env_a_p = self._safety_critic.forward(states, actions_current_policy) - a = self._policy.sample(states)[0].detach() - safety_s_env_a_p = self._safety_critic.forward(states, a) + # first loss term (steers critic towards overapproximation) + loss1 = (-safety_s_env_a_p.mean() + safety_sa_env.mean()) + # bellman operator with torch.no_grad(): safety_next_state = torch.zeros_like(safety_sa_env) for _ in range(self._expectation_estimation_samples): - a = self._policy.sample(next_states)[0].detach() - safety_next_state += self._target_safety_critic.forward(next_states, a) + actions_current_policy = self.sample(states, shielded=self._shielded_action_sampling) + safety_next_state += self._target_safety_critic.forward(next_states, actions_current_policy) safety_next_state /= self._expectation_estimation_samples - safety_next_state = costs.view((-1, 1)) + self._gamma * safety_next_state - safety_sasc_env = torch.square(safety_sa_env - safety_next_state) + safety_next_state *= (1 - dones.view((-1, 1))) + safety_target = costs.view((-1, 1)) + self._gamma * safety_next_state + + # second loss term (mse loss) + loss2 = mse_loss(safety_sa_env, safety_target, reduction='mean') - loss = self._alpha * (safety_sa_env.mean() - safety_s_env_a_p.mean()) + 1/2 * safety_sasc_env.mean() + # overall weighted loss as sum of loss1 and loss2 + loss = self._alpha * loss1 + 1/2 * loss2 self._optim_safety_critic.zero_grad() loss.backward() self._optim_safety_critic.step() return loss.item() + def _tmp(self, states, actions, rewards, next_states, dones): + # importance sampling weights + action_log_prob = self._policy.log_prob(states, actions) + action_log_prob_old = self._policy_old.log_prob(states, actions).detach() + # avoid underflow by subtracting the max + max_log_prob = torch.max(torch.stack((action_log_prob, action_log_prob_old))) + action_log_prob -= max_log_prob + action_log_prob_old -= max_log_prob + # calculate importance ratio + weighting:torch.Tensor = (action_log_prob - action_log_prob_old).exp() + + # reward advantage + with torch.no_grad(): + value_states = self._target_value_network.forward(states) + value_next_states = self._target_value_network.forward(next_states) + value_next_states *= (1 - dones.view((-1, 1))) + value_target = rewards.view((-1, 1)) + self._gamma * value_next_states + reward_advantage = value_target - value_states + objective = (weighting * reward_advantage).mean() + # objective = (weighting * value_states).mean() + + #objective *= -1 + #self._optim_policy.zero_grad() + #objective.backward() + self._policy.zero_grad() + objective.backward() + if False: + print("alp:", 
+        apply_gradient(self._policy, lr=self._lambda_lr)
+
+        return objective.item()
+
+
+    def _primal_dual_gradient(self, states, actions, rewards, next_states, dones):
+        """
+        Updates the policy and the lambda parameter. Computes the objective from Equation 45 of the CSC paper.
+        """
+        # NOTE: currently delegates to the simplified objective in _tmp;
+        # the full primal-dual update below is not reached.
+        return self._tmp(states, actions, rewards, next_states, dones)
+        # importance sampling factor
+        action_log_prob = self._policy.log_prob(states, actions)
+        action_log_prob_old = self._policy_old.log_prob(states, actions).detach()
+        weighting = (action_log_prob - action_log_prob_old).exp()
+
+        # reward advantage
+        with torch.no_grad():
+            value_states = self._target_value_network.forward(states)
+            value_next_states = self._target_value_network.forward(next_states)
+            value_next_states *= (1 - dones.view((-1, 1)))
+            value_target = rewards.view((-1, 1)) + self._gamma * value_next_states
+            reward_advantage = value_target - value_states
+
+        # lambda prime
+        lambda_prime = self._lambda / (1 - self._gamma)
+
+        # cost advantage
+        with torch.no_grad():
+            cost_actions = self._target_safety_critic.forward(states, actions)
+            cost_states = torch.zeros_like(cost_actions)
+            for _ in range(self._expectation_estimation_samples):
+                actions_policy_old = self._sample_policy_old(states, shielded=self._shielded_action_sampling)
+                cost_states += self._target_safety_critic.forward(states, actions_policy_old)
+            cost_states /= self._expectation_estimation_samples
+            cost_advantage = cost_actions - cost_states
+
+        # chi prime term (Eq. 44)
+        chi_prime_term = self._lambda * (self._chi - self._avg_train_unsafe)
+
+        # overall objective
+        objective = (weighting * (reward_advantage - lambda_prime * cost_advantage)).mean() + chi_prime_term
+
+        # reset gradients
+        self._policy.zero_grad()
+        self._lambda.grad = None
-    @torch.no_grad
-    def _update_lambda(self, states, actions):
-        gamma_inv = 1 / (1 - self._gamma)
-        adv = self._cost_advantage(states, actions).mean()
-        chi_prime = self._chi - self._avg_failures
-        gradient = (gamma_inv * adv - chi_prime).item()
-        self._lambda -= gradient
-        return gradient
+        # calculate gradients
+        objective.backward()
+        # apply policy gradient
+        apply_gradient(self._policy, lr=1)  # gradient ascent
-    def update(self, buffer, avg_failures, total_episodes):
-        self._avg_failures = avg_failures
-        self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
+        # obtain action distributions
+        action_dist_old = self._policy_old.distribution(states)
+        action_dist_new = self._policy.distribution(states)
+        kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
-        states, actions, rewards, costs, next_states = buffer.sample(self._batch_size)
-        states = torch.tensor(states)
-        actions = torch.tensor(actions)
-        rewards = torch.tensor(rewards)
-        costs = torch.tensor(costs)
-        next_states = torch.tensor(next_states)
+        # revert update
+        apply_gradient(self._policy, lr=-1)  # gradient descent back
-        piter = self._update_policy(states, actions, rewards, next_states)
-        vloss = self._update_value_network(states, rewards, next_states)
-        sloss = self._update_safety_critic(states, actions, costs, next_states)
-        lgradient = self._update_lambda(states, actions)
+        # calculate beta
+        beta_term = torch.sqrt(self._delta / kl_div)
+        beta_j = self._beta
+
+        # line search
+        for j in range(self._line_search_iterations):
+            beta_j = beta_j * (1 - beta_j)**j
+            beta = beta_j * beta_term
+
+            apply_gradient(self._policy, lr=beta)  # gradient ascent
+            action_dist_new = self._policy.distribution(states)
+            kl_div = kl_divergence(action_dist_old, action_dist_new).mean()
+
+            if kl_div > self._delta:
+                apply_gradient(self._policy, lr=-beta)  # gradient descent back
+            else:
+                break
-        self.soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau)
-        self.soft_update(self._target_value_network, self._value_network, tau=self._tau)
+        # update lambda
+        with torch.no_grad():
+            self._lambda -= self._lambda_lr * self._lambda.grad  # gradient descent
+            self._lambda.clamp_min_(min=0)
+
+        return j
-        self._writer.add_scalar(f"agent/policy_iterations", piter, total_episodes)
-        self._writer.add_scalar(f"agent/value_loss", round(vloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/safety_loss", round(sloss,4), total_episodes)
-        self._writer.add_scalar(f"agent/lambda_gradient", round(lgradient,4), total_episodes)
+    def update(self):
+        """
+        Performs one iteration of updates. Updates the value network, policy network, lambda parameter and safety critic.
+ """ + self._avg_train_unsafe = self._stats.avg_train_unsafe + + states, actions, rewards, costs, next_states, dones = self._buffer.sample(self._batch_size) + states = torch.tensor(states, device=self._device) + actions = torch.tensor(actions, device=self._device) + rewards = torch.tensor(rewards, device=self._device) + costs = torch.tensor(costs, device=self._device) + next_states = torch.tensor(next_states, device=self._device) + dones = torch.tensor(dones, device=self._device) + + vloss = self._update_value_network(states, rewards, next_states, dones) + soft_update(self._target_value_network, self._value_network, tau=self._tau) + piter = self._primal_dual_gradient(states, actions, rewards, next_states, dones) + sloss = self._update_safety_critic(states, actions, costs, next_states, dones) + soft_update(self._target_safety_critic, self._safety_critic, tau=self._tau) + + self._stats.writer.add_scalar("debug/vloss", vloss, self._stats.total_train_episodes) + self._stats.writer.add_scalar("debug/sloss", sloss, self._stats.total_train_episodes) + self._stats.writer.add_scalar("debug/piter", piter, self._stats.total_train_episodes) + self._stats.writer.add_scalar("debug/lambda", self._lambda, self._stats.total_train_episodes) + + + def after_updates(self): + self._unsafety_threshold = (1 - self._gamma) * (self._chi - self._avg_train_unsafe) + hard_update(self._policy_old, self._policy) + + + def _sample_policy_old(self, state, shielded): + """ + Allows action sampling from policy_old. + """ + tmp = self._policy + self._policy = self._policy_old + actions = self.sample(state, shielded=shielded) + self._policy = tmp + return actions def sample(self, state, shielded=True): - state = torch.tensor(state).unsqueeze(0) + """ + Samples and returns one action for every state. If shielded is true, performs rejection sampling according to the CSC paper using the safety critic. 
+        Instead of using a loop, we sample all actions simultaneously and pick:
+          - the first with an estimated unsafety level <= self._unsafety_threshold (epsilon in the CSC paper)
+          - or else, the one that achieves maximum safety, i.e., lowest estimated unsafety
+        """
+        state = torch.tensor(state, device=self._device, dtype=torch.float64)
         if shielded:
-            state = state.expand((self._shield_iterations, -1))
-            action = self._policy.sample(state)[0].detach()
-            unsafety = self._safety_critic.forward(state, action).squeeze()
-            mask = (unsafety <= self._unsafety_threshold).nonzero().flatten()
-            if mask.numel() > 0:
-                idx = mask[0]
-            else:
-                idx = torch.argmin(unsafety)
-            if idx.numel() > 1: idx = idx[0]
-            return action[idx].cpu().numpy()
+            # sample all actions, expand/copy state, estimate safety for all actions at the same time
+            actions = self._policy.sample(state, num_samples=self._shield_iterations)  # shape: (shield_iterations, batch_size, action_size)
+            state = state.unsqueeze(0).expand((self._shield_iterations, -1, -1))  # shape: (shield_iterations, batch_size, state_size)
+            unsafety = self._safety_critic.forward(state, actions)  # shape: (shield_iterations, batch_size, 1)
+            unsafety = unsafety.squeeze(2)  # shape: (shield_iterations, batch_size)
+
+            # check for actions that qualify (unsafety <= threshold), locate the first (if it exists) for every state
+            mask = unsafety <= self._unsafety_threshold  # shape: (shield_iterations, batch_size)
+            row_idx = mask.int().argmax(dim=0)  # idx of first "potentially" safe action
+            batch_idx = torch.arange(0, mask.shape[1], device=unsafety.device)
+
+            # retrieve the estimated unsafety of said action and check whether it is indeed <= threshold;
+            # if for a state all actions are unsafe (> threshold), argmax returns the first one since mask[state] == 0 everywhere
+            unsafety_chosen = unsafety[row_idx, batch_idx]
+            all_unsafe = unsafety_chosen > self._unsafety_threshold
+
+            # for such states retrieve the action with a minimum level of unsafety
+            row_idx[all_unsafe] = unsafety[:, all_unsafe].argmin(dim=0)
+
+            # return the actions selected by the indices
+            result = actions[row_idx, batch_idx, :]
+            return result
+
         else:
-            action = self._policy.sample(state)[0].detach().squeeze().cpu().numpy()
-            return action
\ No newline at end of file
+            # simply sample one action for every state
+            actions = self._policy.sample(state)  # shape: (num_samples, batch_size, action_size)
+            return actions.squeeze(0)
\ No newline at end of file
diff --git a/src/stats.py b/src/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f263bf73e5ef7fb25774d047297b8fcbfce7bf9
--- /dev/null
+++ b/src/stats.py
@@ -0,0 +1,60 @@
+import numpy as np
+
+class Statistics:
+    def __init__(self, writer):
+        self.writer = writer
+
+        # Total training statistics
+        self.total_train_episodes = 0
+        self.total_train_steps = 0
+        self.total_train_unsafe = 0
+        self.total_updates = 0
+
+        # Used for calculating the average unsafe count of previous training episodes
+        self.train_unsafe_history = []
+        self._train_unsafe_avg = 0
+
+        # Used for averaging test results
+        self.test_steps_history = []
+        self.test_reward_history = []
+        self.test_cost_history = []
+        self.test_unsafe_history = []
+
+    @property
+    def train_unsafe_avg(self):
+        if len(self.train_unsafe_history) > 0:
+            self._train_unsafe_avg = sum(self.train_unsafe_history) / len(self.train_unsafe_history)
+            self.train_unsafe_history.clear()
+        return self._train_unsafe_avg
+
+    def log_train_history(self, unsafe: np.ndarray):
+        self.train_unsafe_history += unsafe.tolist()
+
+    def log_test_history(self, steps: np.ndarray, reward: np.ndarray, cost: np.ndarray, unsafe: np.ndarray):
+        self.test_steps_history += steps.tolist()
+        self.test_reward_history += reward.tolist()
+        self.test_cost_history += cost.tolist()
+        self.test_unsafe_history += unsafe.tolist()
+
+    def log_test_tensorboard(self, name_prefix, ticks):
+        self.log_tensorboard(name_prefix + '/avg_steps', np.array(self.test_steps_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_returns', np.array(self.test_reward_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_costs', np.array(self.test_cost_history).mean(), ticks)
+        self.log_tensorboard(name_prefix + '/avg_unsafes', np.array(self.test_unsafe_history).mean(), ticks)
+
+        self.test_steps_history.clear()
+        self.test_reward_history.clear()
+        self.test_cost_history.clear()
+        self.test_unsafe_history.clear()
+
+    def log_update_tensorboard(self, data: dict):
+        for k, v in data.items():
+            name = f"update/{k}"
+            self.writer.add_scalar(name, v, self.total_updates)
+
+    def log_tensorboard(self, name: str, values: np.ndarray, ticks: np.ndarray):
+        if values.size == 1:
+            self.writer.add_scalar(name, values.item(), ticks.item())
+        else:
+            for v, t in zip(values, ticks):
+                self.writer.add_scalar(name, v, t)
\ No newline at end of file
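Note on the shielded sampling in CSCAgent.sample: the shielded branch avoids a Python loop by picking, per state, the first candidate action whose estimated unsafety is at or below the threshold, and otherwise the least-unsafe candidate. A minimal, self-contained sketch of that index-selection trick (the function name, shapes, and toy numbers below are illustrative, not taken from the repository):

    import torch

    def select_shielded_indices(unsafety: torch.Tensor, threshold: float) -> torch.Tensor:
        """unsafety: (num_candidates, batch_size) estimated unsafety per candidate action.
        Returns, per state, the row index of the first candidate <= threshold,
        falling back to the least-unsafe candidate when none qualifies."""
        mask = unsafety <= threshold                      # bool, (num_candidates, batch_size)
        row_idx = mask.int().argmax(dim=0)                # first True per column (0 if none)
        batch_idx = torch.arange(mask.shape[1], device=unsafety.device)

        # argmax returns 0 for all-False columns, so verify the chosen entries
        all_unsafe = unsafety[row_idx, batch_idx] > threshold
        row_idx[all_unsafe] = unsafety[:, all_unsafe].argmin(dim=0)
        return row_idx

    # toy check: 4 candidate actions for 3 states, threshold 0.3
    u = torch.tensor([[0.9, 0.5, 0.7],
                      [0.2, 0.6, 0.8],
                      [0.1, 0.4, 0.6],
                      [0.8, 0.9, 0.5]])
    print(select_shielded_indices(u, 0.3))  # tensor([1, 2, 3])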
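Note on the target-network updates: the removed CSCAgent.soft_update shows the Polyak rule target <- (1 - tau) * target + tau * source; the new code imports soft_update/hard_update helpers, which this sketch assumes to be equivalent (hard_update being the tau = 1 case):

    import torch

    @torch.no_grad()
    def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
        # Polyak averaging: target <- (1 - tau) * target + tau * source
        for tparam, sparam in zip(target.parameters(), source.parameters()):
            tparam.mul_(1.0 - tau).add_(sparam, alpha=tau)

    def hard_update(target: torch.nn.Module, source: torch.nn.Module) -> None:
        # copy all parameters from source into target (equivalent to tau = 1)
        soft_update(target, source, tau=1.0)

    # usage after each optimization step, e.g.:
    # soft_update(target_value_network, value_network, tau=5e-3)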