-
Notifications
You must be signed in to change notification settings - Fork 3
/
qnetwork.py
28 lines (24 loc) · 1.17 KB
/
qnetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from torch import nn
import torch.nn.functional as F
class Qnet(nn.Module):
    """Deep convolutional Q-network over stacks of frames.

    Takes a batch of frame stacks shaped (batch_size, in_channels, height, width)
    and returns a tuple ``(q_values, embedding)`` shaped
    (batch_size, n_actions) and (batch_size, embedding_size) respectively.

    The hard-coded flatten size (7*7*32) assumes 84x84 input images:
    84 -> 20 (conv1, k=8 s=4) -> 9 (conv2, k=4 s=2) -> 7 (conv3, k=3 s=1).

    Args:
        n_actions: number of discrete actions (size of the Q-value output).
        in_channels: number of stacked input frames (default 4).
        embedding_size: width of the fully-connected embedding layer (default 256).
    """

    def __init__(self, n_actions, in_channels=4, embedding_size=256):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=8, stride=4, padding=0)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=0)
        # 7*7*32 is the conv3 output flattened for 84x84 inputs (see class docstring).
        self.fc = nn.Linear(in_features=7 * 7 * 32, out_features=embedding_size)
        self.output = nn.Linear(in_features=embedding_size, out_features=n_actions)

    def forward(self, x):
        """Compute Q-values and the pre-output embedding for a batch of frame stacks.

        Args:
            x: tensor of shape (batch_size, in_channels, 84, 84).

        Returns:
            Tuple ``(q_values, embedding)`` of shapes (batch_size, n_actions)
            and (batch_size, embedding_size). Note the embedding is returned
            *before* the ReLU that feeds the output layer.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        embedding = self.fc(torch.flatten(x, start_dim=1))
        q_values = self.output(F.relu(embedding))
        return q_values, embedding