Skip to content

Commit

Permalink
remove max_std (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
Clyde21c committed Jul 25, 2022
1 parent 402e777 commit f25fd7d
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/meta_rl/maml/algorithm/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __init__(
is_deterministic: bool = False,
init_std: float = 1.0,
min_std: float = 1e-6,
max_std: float = None,
) -> None:
super().__init__(
input_dim=input_dim,
Expand All @@ -63,14 +62,13 @@ def __init__(

self.log_std = torch.Tensor([init_std]).log()
self.log_std = torch.nn.Parameter(self.log_std)
self.min_log_std = torch.Tensor([min_std]).log().item() if min_std is not None else None
self.max_log_std = torch.Tensor([max_std]).log().item() if max_std is not None else None
self.min_log_std = torch.Tensor([min_std]).log().item()

self.is_deterministic = is_deterministic

def get_normal_dist(self, x: torch.Tensor) -> Tuple[Normal, torch.Tensor]:
    """Build a diagonal Normal policy distribution for the input batch.

    Args:
        x: Input observation tensor; forwarded through the parent
            network (``super().forward``) to produce the action mean.

    Returns:
        A tuple ``(dist, mean)`` where ``dist`` is a ``Normal`` with the
        network output as its mean and a state-independent learned std,
        and ``mean`` is returned separately for deterministic use.
    """
    mean = super().forward(x)
    # Clamp log-std from below only: max_std/max_log_std was removed in
    # this commit (#100), so the stale `max=self.max_log_std` clamp is
    # dropped — it would raise AttributeError if kept.
    std = torch.exp(self.log_std.clamp(min=self.min_log_std))
    return Normal(mean, std), mean

def get_log_prob(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit f25fd7d

Please sign in to comment.