Skip to content

Commit

Permalink
DQN: Make the soft_update_model_params method static
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessio committed May 23, 2024
1 parent 52d37a3 commit e2fe543
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions 03_DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@
"\n",
" # TODO: Minimize the loss. Hint: zero_grad the optim, backward on loss, step the optim. \n",
"\n",
" @staticmethod\n",
" def soft_update_model_params(src: torch.nn.Module, dest: torch.nn.Module, tau=1e-3):\n",
" \"\"\"Soft updates model parameters (θ_dest = τ * θ_src + (1 - τ) * θ_dest).\"\"\"\n",
" # TODO: For each dest parameter (get them via the parameters() function), update it with\n",
Expand Down
5 changes: 3 additions & 2 deletions solution/03_DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@
" self.learn()\n",
" self.t_update_target_step = (self.t_update_target_step + 1) % self.update_target_every\n",
" if self.t_update_target_step == 0:\n",
" self.soft_update_model_params(self.qnetwork_local, self.qnetwork_target, self.tau)\n",
" Agent.soft_update_model_params(self.qnetwork_local, self.qnetwork_target, self.tau)\n",
"\n",
" def act(self, state: np.array, eps=0.):\n",
" \"\"\"Makes the agent take an action for the state passed as input.\"\"\"\n",
Expand Down Expand Up @@ -480,7 +480,8 @@
" loss.backward()\n",
" self.optimizer.step() \n",
"\n",
" def soft_update_model_params(self, src: torch.nn.Module, dest: torch.nn.Module, tau=1e-3):\n",
" @staticmethod\n",
" def soft_update_model_params(src: torch.nn.Module, dest: torch.nn.Module, tau=1e-3):\n",
" \"\"\"Soft updates model parameters (θ_dest = τ * θ_src + (1 - τ) * θ_dest).\"\"\"\n",
" for dest_param, src_param in zip(dest.parameters(), src.parameters()):\n",
" dest_param.data.copy_(tau * src_param.data + (1.0 - tau) * dest_param.data)\n",
Expand Down

0 comments on commit e2fe543

Please sign in to comment.