From b0559c966f15dfcb53107767a5b7fd9a3fcfb68a Mon Sep 17 00:00:00 2001
From: Phil <s8phsaue@stud.uni-saarland.de>
Date: Mon, 30 Sep 2024 15:06:02 +0200
Subject: [PATCH] Remove set_default_device; pass device explicitly when creating tensors

---
 main.py       |  1 -
 src/policy.py | 21 ++++++++++-----------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/main.py b/main.py
index 943b15c..259f40f 100644
--- a/main.py
+++ b/main.py
@@ -98,7 +98,6 @@ def setup(args):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     torch.set_default_dtype(torch.float64)
-    torch.set_default_device(args.device)
 
     output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S"))
     writer = SummaryWriter(log_dir=output_dir)
diff --git a/src/policy.py b/src/policy.py
index ff47071..2ef02cd 100644
--- a/src/policy.py
+++ b/src/policy.py
@@ -119,15 +119,14 @@ class CSCAgent():
         def fisher_matrix(glog_prob):
             b, n = glog_prob.shape
             s,i = 32,0  # batch size, index
-            fisher = torch.zeros((n,n))
+            fisher = torch.zeros((n,n), device=self._device)
             while i < b:
                 g = glog_prob[i:i+s, ...]
-                fisher += (torch.func.vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b
+                fisher += (vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b
                 i += s
             return fisher
 
         glog_prob = log_prob_grad(states, actions).detach()
-
         with torch.no_grad():
             fisher = fisher_matrix(glog_prob)
             ahat = estimate_ahat(states, actions, rewards, next_states)
@@ -141,7 +140,7 @@ class CSCAgent():
             
             # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32, --hidden_dim 64 numpy is slow
             time_start = time.time()
-            gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0])
+            gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0], device=self._device)
             time_end = time.time()
             print(f"[DEBUG] Inverse took: {round(time_end - time_start, 2)}s")
 
@@ -202,11 +201,11 @@ class CSCAgent():
         self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures)
 
         states, actions, rewards, costs, next_states = buffer.sample(self._batch_size)
-        states = torch.tensor(states)
-        actions = torch.tensor(actions)
-        rewards = torch.tensor(rewards)
-        costs = torch.tensor(costs)
-        next_states = torch.tensor(next_states)
+        states = torch.tensor(states, device=self._device)
+        actions = torch.tensor(actions, device=self._device)
+        rewards = torch.tensor(rewards, device=self._device)
+        costs = torch.tensor(costs, device=self._device)
+        next_states = torch.tensor(next_states, device=self._device)
 
         piter = self._update_policy(states, actions, rewards, next_states)
         vloss = self._update_value_network(states, rewards, next_states)
@@ -224,7 +223,7 @@ class CSCAgent():
 
     @torch.no_grad
     def sample(self, state, shielded=True):
-        state = torch.tensor(state)
+        state = torch.tensor(state, device=self._device)
         if shielded:
             action = self._policy.sample(state, num_samples=self._shield_iterations)
 
@@ -236,7 +235,7 @@ class CSCAgent():
             is_zero = mask[torch.arange(0, state.shape[1]), argmax] == 0    # check if in a row every col is unsafe
 
             action = action.permute(1, 0, 2)
-            result = torch.zeros((state.shape[1], action.shape[-1]))
+            result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device)
             result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...]     # minimum in case only unsafe actions
             result[~is_zero] = action[~is_zero, argmax[~is_zero], ...]    # else first safe action
 
-- 
GitLab