diff --git a/main.py b/main.py index d72336e07341a103e5448921a4a0fef59eed526b..82f260d2f5cde50fffb4f8842306e1a1ce2b3444 100644 --- a/main.py +++ b/main.py @@ -84,8 +84,8 @@ def cmd_args(): help="Set the device for pytorch to use (default: cuda)") parser.add_argument("--log_dir", action="store", type=str, default="./runs", metavar="PATH", help="Set the output and log directory path (default: ./runs)") - parser.add_argument("--num_threads", action="store", type=int, default=32, metavar="N", - help="Set the maximum number of threads for pytorch (default: 32)") + parser.add_argument("--num_threads", action="store", type=int, default=1, metavar="N", + help="Set the maximum number of threads for pytorch and numpy (default: 1)") args = parser.parse_args() return args @@ -96,6 +96,12 @@ def cmd_args(): def setup(args): torch.set_num_threads(args.num_threads) + os.environ["OMP_NUM_THREADS"] = str(args.num_threads) # export OMP_NUM_THREADS=args.num_threads + os.environ["OPENBLAS_NUM_THREADS"] = str(args.num_threads) # export OPENBLAS_NUM_THREADS=args.num_threads + os.environ["MKL_NUM_THREADS"] = str(args.num_threads) # export MKL_NUM_THREADS=args.num_threads + os.environ["VECLIB_MAXIMUM_THREADS"] = str(args.num_threads) # export VECLIB_MAXIMUM_THREADS=args.num_threads + os.environ["NUMEXPR_NUM_THREADS"] = str(args.num_threads) # export NUMEXPR_NUM_THREADS=args.num_threads + random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed)