Commit 6f419c63 authored by Philipp Sauer

Use `with` context manager for the summary writer

parent 083e0eb1
Branch: master
@@ -17,7 +17,7 @@ SIMILAR = 1
 DISSIMILAR = 0
 class Encoder_DDPG_Adpt_Shield():
-    def __init__(self, args, env):
+    def __init__(self, args, env, writer):
         # total train steps
         self.total_train_steps = args.total_train_steps
@@ -31,6 +31,9 @@ class Encoder_DDPG_Adpt_Shield():
         self.action_high = env.action_space.high
         self.action_low = env.action_space.low
+        # writer
+        self.writer = writer
+
         # Assert that we have actions in the interval [-1, 1]
         assert(np.all(self.action_high == 1))
         assert(np.all(self.action_low == -1))
@@ -280,7 +283,7 @@ class Encoder_DDPG_Adpt_Shield():
         )
         # train knn
-        embeddings = self.base_encoder.predict(x_train)
+        embeddings = self.base_encoder.predict(x_train, verbose=0)
         self.knn = NearestNeighbors(n_neighbors=self.neighbours_count_max).fit(embeddings)
         self.knn_y = y_train
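Background for this hunk: the shield encodes states with a Keras encoder and indexes the embeddings with scikit-learn's NearestNeighbors; the new `verbose=0` merely silences the progress bar that `predict` would otherwise print on every shield retrain. A minimal sketch of the embed-then-index step, with a stand-in encoder and dummy data (only `predict(..., verbose=0)` and the `NearestNeighbors` call mirror the diff):

```python
import numpy as np
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors

# stand-in encoder: maps raw states to a small embedding space
base_encoder = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(8,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(4),
])

x_train = np.random.rand(256, 8).astype(np.float32)  # dummy states
y_train = np.random.randint(0, 2, size=256)          # dummy safety labels

# verbose=0 suppresses Keras's per-call progress bar
embeddings = base_encoder.predict(x_train, verbose=0)

# fit the index on the embeddings; lookups return row indices
# into y_train, so the labels of the nearest states are recovered
knn = NearestNeighbors(n_neighbors=5).fit(embeddings)
knn_y = y_train

dists, idx = knn.kneighbors(embeddings[:1])
print(knn_y[idx])  # labels of the 5 nearest neighbours
```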
@@ -421,12 +424,13 @@ class Encoder_DDPG_Adpt_Shield():
         avg_shield_activations /= eval_episodes
         # Log statistics
-        tf.summary.scalar('test/avg_returns', data=avg_reward, step=total_train_steps)
-        tf.summary.scalar('test/avg_steps', data=avg_steps, step=total_train_steps)
-        tf.summary.scalar('test/avg_cost', data=avg_cost, step=total_train_steps)
-        tf.summary.scalar('test/avg_safety_violations', data=avg_safety_violations, step=total_train_steps)
-        tf.summary.scalar('test/avg_goal_reaches', data=avg_goal_reaches, step=total_train_steps)
-        tf.summary.scalar('test/avg_shield_activations', data=avg_shield_activations, step=total_train_steps)
+        with self.writer.as_default(step=total_train_steps):
+            tf.summary.scalar('test/avg_returns', data=avg_reward)
+            tf.summary.scalar('test/avg_steps', data=avg_steps)
+            tf.summary.scalar('test/avg_cost', data=avg_cost)
+            tf.summary.scalar('test/avg_safety_violations', data=avg_safety_violations)
+            tf.summary.scalar('test/avg_goal_reaches', data=avg_goal_reaches)
+            tf.summary.scalar('test/avg_shield_activations', data=avg_shield_activations)
     # the RL loop
     def train(self, shield_train_start, shield_train_interval, eval_episodes, eval_interval, evaluate=True):
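All three logging hunks apply the same mechanics: `tf.summary.scalar` needs an active writer and a step, and `SummaryWriter.as_default(step=...)` supplies both to every call inside the block, which is why the per-call `step=` arguments disappear. A hedged before/after sketch (log directory and tag names are illustrative):

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/example_logs")

# Before: each call passes step explicitly and relies on a writer
# previously made global via set_as_default():
#   tf.summary.scalar('test/avg_returns', data=1.0, step=100)

# After: the context manager activates this specific writer and
# sets the default step once for every call in the block.
with writer.as_default(step=100):
    tf.summary.scalar('test/avg_returns', data=1.0)
    tf.summary.scalar('test/avg_steps', data=42.0)
```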
@@ -503,15 +507,16 @@ class Encoder_DDPG_Adpt_Shield():
             total_safety_violations += episode_safety_violations
             total_shield_activations += episode_shield_activations
-            tf.summary.scalar('train/returns', data=episode_reward, step=total_train_steps)
-            tf.summary.scalar('train/steps', data=episode_steps, step=total_train_steps)
-            tf.summary.scalar('train/costs', data=episode_cost, step=total_train_steps)
-            # tf.summary.scalar('train/goal_reaches', data=episode_goal_reaches, step=total_train_steps)
+            with self.writer.as_default(step=total_train_steps):
+                tf.summary.scalar('train/returns', data=episode_reward)
+                tf.summary.scalar('train/steps', data=episode_steps)
+                tf.summary.scalar('train/costs', data=episode_cost)
+                # tf.summary.scalar('train/goal_reaches', data=episode_goal_reaches)
-            tf.summary.scalar('train/total_shield_activations', data=total_shield_activations, step=total_train_steps)
-            tf.summary.scalar('train/total_episodes', data=total_episodes, step=total_train_steps)
-            tf.summary.scalar('train/total_safety_violations', data=total_safety_violations, step=total_train_steps)
-            tf.summary.scalar('train/total_updates', data=total_updates, step=total_train_steps)
+                tf.summary.scalar('train/total_shield_activations', data=total_shield_activations)
+                tf.summary.scalar('train/total_episodes', data=total_episodes)
+                tf.summary.scalar('train/total_safety_violations', data=total_safety_violations)
+                tf.summary.scalar('train/total_updates', data=total_updates)
             # store safety violations for neighbours_count calculation
             self.metrics["safety_violations"].append(total_safety_violations)
@@ -519,7 +524,8 @@
             if trained:
                 # update neighbours value dynamically
                 self.update_neighbors()
-                tf.summary.scalar('debug/neighbors', data=self.neighbours_count, step=total_train_steps)
+                with self.writer.as_default(step=total_train_steps):
+                    tf.summary.scalar('debug/neighbors', data=self.neighbours_count)
             # Train encoder
             if total_episodes >= shield_next_train_episode:
......
@@ -84,15 +84,14 @@ np.random.seed(args.seed)
 tf.random.set_seed(args.seed)
 env = create_env(args)
 env.set_seed(args.seed)
 _ = env.reset(seed=args.seed)
 output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
-file_writer = tf.summary.create_file_writer(output_dir)
-file_writer.set_as_default()
+writer = tf.summary.create_file_writer(output_dir)
 with open(os.path.join(output_dir, "config.json"), "w") as file:
     json.dump(args.__dict__, file, indent=2)
-enc_ddpg = Encoder_DDPG_Adpt_Shield(args, env)
+enc_ddpg = Encoder_DDPG_Adpt_Shield(args, env, writer)
 #---------------------------------- training ----------------------------------#
 enc_ddpg.train(
......
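Taken together, the two files now wire the writer explicitly: the script creates it once, drops the `set_as_default()` call, and hands it to the trainer, which scopes its own summary calls. A condensed sketch of that wiring (the `Trainer` class and paths are placeholders for `Encoder_DDPG_Adpt_Shield` and `args.log_dir`):

```python
import os
import datetime
import tensorflow as tf

class Trainer:  # placeholder for Encoder_DDPG_Adpt_Shield
    def __init__(self, writer):
        self.writer = writer  # injected, not global

    def train_step(self, step):
        with self.writer.as_default(step=step):
            tf.summary.scalar('train/returns', data=0.0)

log_dir = "/tmp/runs"  # placeholder for args.log_dir
output_dir = os.path.join(
    log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))

# created once here and passed down; no set_as_default() needed
writer = tf.summary.create_file_writer(output_dir)
Trainer(writer).train_step(step=1)
```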