From a18f2dd655da225711a60c7c4d2720050de2c13c Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Wed, 2 Oct 2024 11:42:33 +0200 Subject: [PATCH] Added os environ calls --- main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index d72336e..82f260d 100644 --- a/main.py +++ b/main.py @@ -84,8 +84,8 @@ def cmd_args(): help="Set the device for pytorch to use (default: cuda)") parser.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH", help="Set the output and log directory path (default: ./runs)") - parser.add_argument("--num_threads", action="store", type=int, default=32, metavar="N", - help="Set the maximum number of threads for pytorch (default: 32)") + parser.add_argument("--num_threads", action="store", type=int, default=1, metavar="N", + help="Set the maximum number of threads for pytorch and numpy (default: 1)") args = parser.parse_args() return args @@ -96,6 +96,12 @@ def cmd_args(): def setup(args): torch.set_num_threads(args.num_threads) + os.environ["OMP_NUM_THREADS"] = str(args.num_threads) # export OMP_NUM_THREADS=args.num_threads + os.environ["OPENBLAS_NUM_THREADS"] = str(args.num_threads) # export OPENBLAS_NUM_THREADS=args.num_threads + os.environ["MKL_NUM_THREADS"] = str(args.num_threads) # export MKL_NUM_THREADS=args.num_threads + os.environ["VECLIB_MAXIMUM_THREADS"] = str(args.num_threads) # export VECLIB_MAXIMUM_THREADS=args.num_threads + os.environ["NUMEXPR_NUM_THREADS"] = str(args.num_threads) # export NUMEXPR_NUM_THREADS=args.num_threads + random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) -- GitLab