From a18f2dd655da225711a60c7c4d2720050de2c13c Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Wed, 2 Oct 2024 11:42:33 +0200
Subject: [PATCH] Added os environ calls

---
 main.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index d72336e..82f260d 100644
--- a/main.py
+++ b/main.py
@@ -84,8 +84,8 @@ def cmd_args():
                         help="Set the device for pytorch to use (default: cuda)")
     parser.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH",
                         help="Set the output and log directory path (default: ./runs)")
-    parser.add_argument("--num_threads", action="store", type=int, default=32, metavar="N",
-                        help="Set the maximum number of threads for pytorch (default: 32)")
+    parser.add_argument("--num_threads", action="store", type=int, default=1, metavar="N",
+                        help="Set the maximum number of threads for pytorch and numpy (default: 1)")
     args = parser.parse_args()
     return args
 
@@ -96,6 +96,12 @@ def cmd_args():
 def setup(args):
     torch.set_num_threads(args.num_threads)
 
+    os.environ["OMP_NUM_THREADS"] = str(args.num_threads) # export OMP_NUM_THREADS=args.num_threads
+    os.environ["OPENBLAS_NUM_THREADS"] = str(args.num_threads) # export OPENBLAS_NUM_THREADS=args.num_threads
+    os.environ["MKL_NUM_THREADS"] = str(args.num_threads) # export MKL_NUM_THREADS=args.num_threads
+    os.environ["VECLIB_MAXIMUM_THREADS"] = str(args.num_threads) # export VECLIB_MAXIMUM_THREADS=args.num_threads
+    os.environ["NUMEXPR_NUM_THREADS"] = str(args.num_threads) # export NUMEXPR_NUM_THREADS=args.num_threads
+
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-- 
GitLab