Philipp Sauer / csc / Commits

Commit 69b76df7, authored 2 months ago by Philipp Sauer

    Reworked main

Parent: 88376f26
No related tags or merge requests reference this commit.

Showing 1 changed file: main.py, 188 additions and 169 deletions (view file @ 69b76df7)
The side-by-side diff is reconstructed below into the post-commit main.py; behavior the commit removed is summarized in a short note before each block. The import list was reordered: safety_gymnasium moved below the standard-library imports, and the src.* imports are now in alphabetical order, with Statistics last.

```python
import argparse
import numpy as np
import torch
import random
import os
import json
import datetime
import safety_gymnasium
from torch.utils.tensorboard import SummaryWriter
from src.buffer import ReplayBuffer
from src.policy import CSCAgent
from src.stats import Statistics
```
Changes to cmd_args(): the arguments are now organized into argument groups (Environment, Train and Test, Buffer, Agent, Common); the parser uses ArgumentDefaultsHelpFormatter with max_help_position=40, so the hand-written "(default: ...)" suffixes were dropped from the help strings; --total_steps was renamed to --total_train_steps; --shielded_action_sampling moved from the train section into the Agent group; and --enforce_cost_limit (abort an episode once the cost limit is reached) was removed entirely.

```python
##################
# ARGPARSER
##################
def cmd_args():
    parser = argparse.ArgumentParser(
        formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))
    # environment args
    env_args = parser.add_argument_group('Environment')
    env_args.add_argument("--env_id", action="store", type=str, default="SafetyPointGoal1-v0", metavar="ID",
                          help="Set the environment")
    env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
                          help="Set a cost limit/budget")
    env_args.add_argument("--num_vectorized_envs", action="store", type=int, default=16, metavar="N",
                          help="Sets the number of vectorized environments")
    # train args
    train_test_args = parser.add_argument_group('Train and Test')
    train_test_args.add_argument("--train_episodes", action="store", type=int, default=16, metavar="N",
                                 help="Number of episodes until policy optimization")
    train_test_args.add_argument("--train_until_test", action="store", type=int, default=2, metavar="N",
                                 help="Perform evaluation after N * total_train_episodes episodes of training")
    train_test_args.add_argument("--update_iterations", action="store", type=int, default=3, metavar="N",
                                 help="Number of updates performed after each training step")
    train_test_args.add_argument("--test_episodes", action="store", type=int, default=32, metavar="N",
                                 help="Number of episodes used for testing")
    train_test_args.add_argument("--total_train_steps", action="store", type=int, default=25_000_000, metavar="N",
                                 help="Total number of steps until training is finished")
    train_test_args.add_argument("--batch_size", action="store", type=int, default=1024, metavar="N",
                                 help="Batch size used for training")
    train_test_args.add_argument("--tau", action="store", type=float, default=0.05, metavar="N",
                                 help="Factor used in soft update of target network")
    # buffer args
    buffer_args = parser.add_argument_group('Buffer')
    buffer_args.add_argument("--buffer_capacity", action="store", type=int, default=50_000, metavar="N",
                             help="Define the maximum capacity of the replay buffer")
    buffer_args.add_argument("--clear_buffer", action="store_true", default=False,
                             help="Clear Replay Buffer after every optimization step")
    # csc args
    csc_args = parser.add_argument_group('Agent')
    csc_args.add_argument("--shield_iterations", action="store", type=int, default=100, metavar="N",
                          help="Maximum number of actions sampled during shielding")
    csc_args.add_argument("--line_search_iterations", action="store", type=int, default=20, metavar="N",
                          help="Maximum number of line search update iterations")
    csc_args.add_argument("--expectation_estimation_samples", action="store", type=int, default=20, metavar="N",
                          help="Number of samples to estimate expectations")
    csc_args.add_argument("--csc_chi", action="store", type=float, default=0.05, metavar="N",
                          help="Set the value of chi")
    csc_args.add_argument("--csc_delta", action="store", type=float, default=0.01, metavar="N",
                          help="Set the value of delta")
    csc_args.add_argument("--csc_gamma", action="store", type=float, default=0.99, metavar="N",
                          help="Set the value of gamma")
    csc_args.add_argument("--csc_beta", action="store", type=float, default=0.7, metavar="N",
                          help="Set the value of beta")
    csc_args.add_argument("--csc_alpha", action="store", type=float, default=0.5, metavar="N",
                          help="Set the value of alpha")
    csc_args.add_argument("--csc_lambda", action="store", type=float, default=1.0, metavar="N",
                          help="Set the initial value of lambda")
    csc_args.add_argument("--csc_safety_critic_lr", action="store", type=float, default=2e-4, metavar="N",
                          help="Learn rate for the safety critic")
    csc_args.add_argument("--csc_value_network_lr", action="store", type=float, default=1e-3, metavar="N",
                          help="Learn rate for the value network")
    csc_args.add_argument("--csc_lambda_lr", action="store", type=float, default=4e-2, metavar="N",
                          help="Learn rate for the lambda dual variable")
    csc_args.add_argument("--hidden_dim", action="store", type=int, default=32, metavar="N",
                          help="Hidden dimension of the networks")
    csc_args.add_argument("--sigmoid_activation", action="store_true", default=False,
                          help="Apply sigmoid activation to the safety critics output")
    csc_args.add_argument("--shielded_action_sampling", action="store_true", default=False,
                          help="Sample shielded actions when performing parameter updates")
    # common args
    common_args = parser.add_argument_group('Common')
    common_args.add_argument("--seed", action="store", type=int, default=42, metavar="N",
                             help="Set a custom seed for the rng")
    common_args.add_argument("--device", action="store", type=str, default="cuda", metavar="DEVICE",
                             help="Set the device for pytorch to use")
    common_args.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
                             help="Set the output and log directory path")
    common_args.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
                             help="Set the maximum number of threads for pytorch and numpy")
    args = parser.parse_args()
    return args
```

… (unchanged lines collapsed in the diff view) @@ -98,147 +101,163 @@ def cmd_args():
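A note on the new formatter_class: ArgumentDefaultsHelpFormatter appends "(default: ...)" to every help string automatically, which is why the reworked help texts drop the hand-written defaults. A minimal standalone sketch (not part of the commit) showing the effect:

```python
import argparse

# Same formatter construction as the reworked cmd_args(): the lambda lets us
# widen max_help_position while still getting automatic "(default: ...)" text.
parser = argparse.ArgumentParser(
    formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=40))
env_args = parser.add_argument_group('Environment')
env_args.add_argument("--cost_limit", action="store", type=float, default=25, metavar="N",
                      help="Set a cost limit/budget")

# format_help() renders the grouped sections; "(default: 25)" is added by the formatter.
print(parser.format_help())
```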
Changes to setup(): it gains a docstring, the thread-cap environment variables, and torch.set_default_dtype(torch.float64); the vectorized environment is now created with asynchronous=False (previously True); Statistics is constructed with the SummaryWriter and CSCAgent with the buffer and stats (previously CSCAgent(env, args, writer) and a bare Statistics()); the writer is no longer part of the return value.

```python
##################
# SETUP
##################
def setup(args):
    """
    Performs setup like fixing seeds, initializing env and agent, buffer and stats.
    """
    torch.set_num_threads(args.num_threads)
    os.environ["OMP_NUM_THREADS"] = str(args.num_threads)         # export OMP_NUM_THREADS=args.num_threads
    os.environ["OPENBLAS_NUM_THREADS"] = str(args.num_threads)    # export OPENBLAS_NUM_THREADS=args.num_threads
    os.environ["MKL_NUM_THREADS"] = str(args.num_threads)         # export MKL_NUM_THREADS=args.num_threads
    os.environ["VECLIB_MAXIMUM_THREADS"] = str(args.num_threads)  # export VECLIB_MAXIMUM_THREADS=args.num_threads
    os.environ["NUMEXPR_NUM_THREADS"] = str(args.num_threads)     # export NUMEXPR_NUM_THREADS=args.num_threads
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_default_dtype(torch.float64)
    output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
    writer = SummaryWriter(log_dir=output_dir)
    with open(os.path.join(output_dir, "config.json"), "w") as file:
        json.dump(args.__dict__, file, indent=2)
    env = safety_gymnasium.vector.make(env_id=args.env_id, num_envs=args.num_vectorized_envs, asynchronous=False)
    buffer = ReplayBuffer(env=env, cap=args.buffer_capacity)
    stats = Statistics(writer)
    agent = CSCAgent(env, args, buffer, stats)
    return env, agent, buffer, stats
```
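One caveat worth knowing about the thread caps set in setup() (an observation about the sketch below, not something this diff changes): OMP_NUM_THREADS and the other BLAS variables are usually read when numpy and torch first initialize their backends, so assigning them after those modules are imported at the top of main.py may have no effect. A sketch of the safer ordering, assuming a cap of 1 as in the default --num_threads:

```python
# Hypothetical entry-point ordering: set the thread caps before numpy/torch
# are imported, since BLAS/OpenMP pools are typically sized at library load time.
import os

for var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ[var] = "1"

import numpy as np   # noqa: E402  (imported after the env vars on purpose)
import torch         # noqa: E402

torch.set_num_threads(1)  # torch's own intra-op pool can still be capped at runtime
```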
run_vectorized_exploration() was rewritten. The old version accumulated avg_steps/avg_reward/avg_cost/avg_failures, bumped stats.total_episodes, stats.total_steps, and stats.total_failures, reset the whole vector environment whenever any episode finished, applied the now-removed --enforce_cost_limit mask (episodes over budget were masked out during training), and returned the four averages. The new version keeps per-env open/running episode counters, lets individual environments autoreset (reading info['final_observation'] for the terminal transition), stores done flags in the replay buffer, and reports through the Statistics class instead of returning values.

```python
##################
# EXPLORATION
##################
@torch.no_grad
def run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True):
    # track currently running and leftover episodes
    open_episodes = args.train_episodes if train else args.test_episodes
    running_episodes = args.num_vectorized_envs
    # initialize mask and stats per episode
    mask = np.ones(args.num_vectorized_envs, dtype=np.bool_)
    episode_steps = np.zeros_like(mask, dtype=np.uint64)
    episode_reward = np.zeros_like(mask, dtype=np.float64)
    episode_cost = np.zeros_like(mask, dtype=np.float64)
    # adjust mask in case we have fewer runs than environments
    if open_episodes < args.num_vectorized_envs:
        mask[open_episodes:] = False
        running_episodes = open_episodes
    open_episodes -= running_episodes
    state, info = env.reset(seed=random.randint(0, 2**31 - 1))
    while running_episodes > 0:
        # sample and execute actions
        with torch.no_grad():
            actions = agent.sample(state, shielded=shielded).cpu().numpy()
        next_state, reward, cost, terminated, truncated, info = env.step(actions)
        done = terminated | truncated
        not_done_masked = ((~done) & mask)
        done_masked = done & mask
        done_masked_count = done_masked.sum()
        # increment stats
        episode_steps[mask] += 1
        episode_reward[mask] += reward[mask]
        episode_cost[mask] += cost[mask]
        # if any run has finished, we need to take special care
        # 1. train: extract final_observation from info dict (single envs autoreset, no manual reset needed) and add to buffer
        # 2. train+test: log episode stats using Statistics class
        # 3. train+test: reset episode stats
        # 4. train+test: adjust mask (if necessary)
        if done_masked_count > 0:
            if train:
                # add experiences to buffer
                buffer.add(state[not_done_masked], actions[not_done_masked], reward[not_done_masked],
                           cost[not_done_masked], next_state[not_done_masked], done[not_done_masked])
                buffer.add(state[done_masked], actions[done_masked], reward[done_masked],
                           cost[done_masked], np.stack(info['final_observation'], axis=0)[done_masked],
                           done[done_masked])
                # record finished episodes
                stats.record_train(
                    num_episodes=done_masked_count,
                    returns=episode_reward[done_masked],
                    costs=episode_cost[done_masked],
                    steps=episode_steps[done_masked],
                    unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
                )
                stats.total_train_episodes += done_masked_count
                stats.total_train_steps += episode_steps[done_masked].sum()
                stats.total_train_unsafe += (episode_cost[done_masked] > args.cost_limit).sum()
            else:
                # record finished episodes
                # stats module performs averaging over all episodes
                stats.record_test(
                    shielded=shielded,
                    avg_returns=episode_reward[done_masked],
                    avg_costs=episode_cost[done_masked],
                    avg_steps=episode_steps[done_masked],
                    avg_unsafe=(episode_cost[done_masked] > args.cost_limit).astype(np.uint8)
                )
            # reset episode stats
            episode_steps[done_masked] = 0
            episode_reward[done_masked] = 0
            episode_cost[done_masked] = 0
            # adjust mask, running and open episodes counter
            if open_episodes < done_masked_count:
                # fewer left than just finished
                done_masked_idxs = done_masked.nonzero()[0]
                mask[done_masked_idxs[open_episodes:]] = False
                running_episodes -= (done_masked_count - open_episodes)
                open_episodes = 0
            else:
                # at least as many left than just finished
                open_episodes -= done_masked_count
        # no run has finished, just record experiences (if training)
        else:
            if train:
                buffer.add(state[mask], actions[mask], reward[mask], cost[mask], next_state[mask], done[mask])
        state = next_state
    # after exploration, flush stats
    stats.after_exploration(train, shielded)
```
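src.stats is not part of this diff, but the call sites pin down its interface: Statistics(writer), begin/end named phase timers, record_train and record_test accumulators, after_exploration to flush, print for reporting, and total_train_* counters incremented by the exploration loop. A rough, hypothetical stand-in that would satisfy these calls (the real implementation may differ):

```python
import time
import numpy as np

class Statistics:
    """Hypothetical stand-in inferred from the call sites in main.py; not the real src.stats."""

    def __init__(self, writer=None):
        self.writer = writer            # SummaryWriter; the real class presumably logs through it
        self.total_train_episodes = 0   # incremented externally by run_vectorized_exploration
        self.total_train_steps = 0
        self.total_train_unsafe = 0
        self._timers = {}
        self._test_records = []

    def begin(self, name):
        # start a named phase timer ("train", "update", "test")
        self._timers[name] = time.perf_counter()

    def end(self, name):
        elapsed = time.perf_counter() - self._timers.pop(name)
        print(f"[{name}] {elapsed:.2f}s")

    def record_train(self, num_episodes, returns, costs, steps, unsafe):
        # called once per batch of finished training episodes
        if self.writer is not None:
            self.writer.add_scalar("train/avg_return", float(np.mean(returns)), self.total_train_episodes)
            self.writer.add_scalar("train/avg_cost", float(np.mean(costs)), self.total_train_episodes)

    def record_test(self, shielded, avg_returns, avg_costs, avg_steps, avg_unsafe):
        self._test_records.append(
            (shielded, np.mean(avg_returns), np.mean(avg_costs), np.mean(avg_unsafe)))

    def after_exploration(self, train, shielded):
        # flush whatever was accumulated during the exploration run
        if self.writer is not None:
            self.writer.flush()

    def print(self):
        print(f"episodes={self.total_train_episodes} steps={self.total_train_steps} "
              f"unsafe={self.total_train_unsafe}")
```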
Changes to main(): the old loop stopped once stats.total_steps >= args.total_steps, unpacked the four averages from each exploration call, printed [TRAIN] and [TEST_SHIELDED]/[TEST_UNSHIELDED] summaries, logged test/avg_reward, test/avg_cost, and test/avg_failures per postfix via writer.add_scalar, drove updates with stats.total_updates += 1 and agent.update(buffer=buffer, avg_failures=avg_failures, total_episodes=stats.total_episodes + i), and ended each round with writer.flush(). The new loop wraps the phases in stats.begin/end timers, calls the argument-less agent.update() (the agent now owns buffer and stats), and reports via stats.print(). The test-loop call appears without the stats argument in the rendered diff, most likely a rendering artifact, so it is restored below. The __main__ block now unpacks setup() directly.

```python
##################
# MAIN LOOP
##################
def main(args, env, agent, buffer, stats):
    finished = False
    # note: this hunk removes the old total-steps stop check and shows no replacement,
    # so 'finished' is never set to True here
    while not finished:
        for _ in range(args.train_until_test):
            stats.begin(name="train")
            run_vectorized_exploration(args, env, agent, buffer, stats, train=True, shielded=True)
            stats.end(name="train")
            stats.begin(name="update")
            for _ in range(args.update_iterations):
                agent.update()
            stats.end(name="update")
            agent.after_updates()
            if args.clear_buffer:
                buffer.clear()
        stats.begin(name="test")
        for shielded in [True, False]:
            run_vectorized_exploration(args, env, agent, buffer, stats, train=False, shielded=shielded)
        stats.end(name="test")
        stats.print()


if __name__ == '__main__':
    args = cmd_args()
    main(args, *setup(args))
```
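--tau is documented as the factor of the target network's soft update; the conventional Polyak update this usually refers to is sketched below. CSCAgent lives in src.policy and is not shown in this diff, so treating its update this way is an assumption about its internals. With the reconstructed defaults, a run would be launched as `python main.py --env_id SafetyPointGoal1-v0 --device cuda`.

```python
import torch

@torch.no_grad()
def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.05):
    """Polyak averaging: target <- (1 - tau) * target + tau * source.

    Hypothetical helper; whether CSCAgent.after_updates() does exactly this is
    an assumption, but --tau's help text describes this kind of update.
    """
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.mul_(1.0 - tau).add_(tau * s_param)
```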