Skip to content

Commit

Permalink
Fix the soft_actor_critic model (#2326)
Browse files Browse the repository at this point in the history
Summary:
Unfortunately, #2318 has a bug that breaks the `soft_actor_critic` model.

Pull Request resolved: #2326

Reviewed By: aaronenyeshi

Differential Revision: D58871386

Pulled By: xuzhao9

fbshipit-source-id: 5f8b5fbe00722ccb647b08a8089fd52a7719208c
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jun 21, 2024
1 parent 1425f68 commit d910b8a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions torchbenchmark/models/soft_actor_critic/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import distributions as pyd
from torch import nn

from . import utils
from . import sac_utils
from torchbenchmark.util.distribution import SquashedNormal

def weight_init(m):
Expand All @@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50):
self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
)
for _ in range(3):
output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
)

Expand Down Expand Up @@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50):
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
)

Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/models/soft_actor_critic/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch

from . import envs, nets, replay, utils
from . import envs, nets, replay, sac_utils


class SACAgent:
Expand Down

0 comments on commit d910b8a

Please sign in to comment.