From b0559c966f15dfcb53107767a5b7fd9a3fcfb68a Mon Sep 17 00:00:00 2001 From: Phil <s8phsaue@stud.uni-saarland.de> Date: Mon, 30 Sep 2024 15:06:02 +0200 Subject: [PATCH] removed set_default_device --- main.py | 1 - src/policy.py | 21 ++++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 943b15c..259f40f 100644 --- a/main.py +++ b/main.py @@ -98,7 +98,6 @@ def setup(args): np.random.seed(args.seed) torch.manual_seed(args.seed) torch.set_default_dtype(torch.float64) - torch.set_default_device(args.device) output_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%d_%m_%y__%H_%M_%S")) writer = SummaryWriter(log_dir=output_dir) diff --git a/src/policy.py b/src/policy.py index ff47071..2ef02cd 100644 --- a/src/policy.py +++ b/src/policy.py @@ -119,15 +119,14 @@ class CSCAgent(): def fisher_matrix(glog_prob): b, n = glog_prob.shape s,i = 32,0 # batch size, index - fisher = torch.zeros((n,n)) + fisher = torch.zeros((n,n), device=self._device) while i < b: g = glog_prob[i:i+s, ...] - fisher += (torch.func.vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b + fisher += (vmap(torch.outer, in_dims=(0,0))(g,g)).sum(dim=0)/b i += s return fisher glog_prob = log_prob_grad(states, actions).detach() - with torch.no_grad(): fisher = fisher_matrix(glog_prob) ahat = estimate_ahat(states, actions, rewards, next_states) @@ -141,7 +140,7 @@ class CSCAgent(): # NOTE: pytorch somehow is very slow at computing pseudoinverse on --hidden_dim 32, --hidden_dim 64 numpy is slow time_start = time.time() - gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0]) + gradient_term = torch.tensor(np.linalg.lstsq(fisher.cpu().numpy(), gJ.cpu().numpy(), rcond=None)[0], device=self._device) time_end = time.time() print(f"[DEBUG] Inverse took: {round(time_end - time_start, 2)}s") @@ -202,11 +201,11 @@ class CSCAgent(): self._unsafety_threshold = (1 - self._gamma) * (self._chi - avg_failures) states, actions, rewards, costs, next_states = buffer.sample(self._batch_size) - states = torch.tensor(states) - actions = torch.tensor(actions) - rewards = torch.tensor(rewards) - costs = torch.tensor(costs) - next_states = torch.tensor(next_states) + states = torch.tensor(states, device=self._device) + actions = torch.tensor(actions, device=self._device) + rewards = torch.tensor(rewards, device=self._device) + costs = torch.tensor(costs, device=self._device) + next_states = torch.tensor(next_states, device=self._device) piter = self._update_policy(states, actions, rewards, next_states) vloss = self._update_value_network(states, rewards, next_states) @@ -224,7 +223,7 @@ class CSCAgent(): @torch.no_grad def sample(self, state, shielded=True): - state = torch.tensor(state) + state = torch.tensor(state, device=self._device) if shielded: action = self._policy.sample(state, num_samples=self._shield_iterations) @@ -236,7 +235,7 @@ class CSCAgent(): is_zero = mask[torch.arange(0, state.shape[1]), argmax] == 0 # check if in a row every col is unsafe action = action.permute(1, 0, 2) - result = torch.zeros((state.shape[1], action.shape[-1])) + result = torch.zeros((state.shape[1], action.shape[-1]), device=self._device) result[is_zero] = action[is_zero, torch.argmin(unsafety[is_zero, ...], dim=1), ...] # minimum in case only unsafe actions result[~is_zero] = action[~is_zero, argmax[~is_zero], ...] # else first safe action -- GitLab