Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 17, 2024
1 parent e538fdc commit e815532
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
16 changes: 16 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device):
assert (a <= d.max).all()
lp = d.log_prob(a)
assert torch.isfinite(lp).all()
assert not torch.isfinite(
d.log_prob(
torch.as_tensor(d.min, device=device).expand(
(*d.batch_shape, *d.event_shape)
)
- 1e-2
)
).any()
assert not torch.isfinite(
d.log_prob(
torch.as_tensor(d.max, device=device).expand(
(*d.batch_shape, *d.event_shape)
)
+ 1e-2
)
).any()

def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device):
torch.manual_seed(0)
Expand Down
14 changes: 13 additions & 1 deletion torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,25 @@ def mode(self):
return torch.max(torch.stack([m, a], -1), dim=-1)[0]

def log_prob(self, value, **kwargs):
above_or_below = (self.min > value) | (self.max < value)
a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0
a = a.expand_as(value)
b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0
b = b.expand_as(value)
value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
return super().log_prob(value, **kwargs)
lp = super().log_prob(value, **kwargs)
if above_or_below.any():
if self.event_shape:
above_or_below = above_or_below.flatten(-len(self.event_shape), -1).any(
-1
)
lp = torch.masked_fill(
lp,
above_or_below.expand_as(lp),
torch.tensor(-float("inf"), device=lp.device, dtype=lp.dtype),
)
return lp


class TanhNormal(FasterTransformedDistribution):
Expand Down

0 comments on commit e815532

Please sign in to comment.