Merge pull request #18 from RMeli/weightsandbiases
Remove weight initialisation from model __init__
RMeli authored Dec 2, 2021
2 parents 68ea7ce + cfbc41e commit 39985f3
Showing 2 changed files with 24 additions and 54 deletions.
75 changes: 22 additions & 53 deletions gnina/models.py
@@ -21,6 +21,28 @@
 from torch import nn
 
 
+def weights_and_biases_init(m: nn.Module) -> None:
+    """
+    Initialize the weights and biases of the model.
+
+    Parameters
+    ----------
+    m : nn.Module
+        Module (layer) to initialize
+
+    Notes
+    -----
+    This function is used to initialize the weights of the model for both
+    convolutional and linear layers. Weights are initialized using uniform
+    Xavier initialization while biases are set to zero.
+
+    https://github.com/gnina/libmolgrid/blob/e6d5f36f1ae03f643ca69cdec1625ac52e653f88/test/test_torch_cnn.py#L45-L48
+    """
+    if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
+        nn.init.xavier_uniform_(m.weight.data)
+        nn.init.constant_(m.bias.data, 0.0)
+
+
 class Default2017(nn.Module):
     """
     GNINA default2017 model architecture.
@@ -134,13 +156,6 @@ def __init__(self, input_dims: Tuple):
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-                # TODO: Initialize bias to zero?
-                # TODO: See https://github.com/gnina/libmolgrid/blob/e6d5f36f1ae03f643ca69cdec1625ac52e653f88/test/test_torch_cnn.py#L48
-
     def forward(self, x: torch.Tensor):
         """
         Parameters
@@ -197,13 +212,6 @@ def __init__(self, input_dims: Tuple):
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-                # TODO: Initialize bias to zero?
-                # TODO: See https://github.com/gnina/libmolgrid/blob/e6d5f36f1ae03f643ca69cdec1625ac52e653f88/test/test_torch_cnn.py#L48
-
     def forward(self, x: torch.Tensor):
         """
         Parameters
@@ -368,13 +376,6 @@ def __init__(self, input_dims: Tuple):
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-                # TODO: Initialize bias to zero?
-                # TODO: See https://github.com/gnina/libmolgrid/blob/e6d5f36f1ae03f643ca69cdec1625ac52e653f88/test/test_torch_cnn.py#L48
-
     def forward(self, x: torch.Tensor):
         """
         Parameters
@@ -430,13 +431,6 @@ def __init__(self, input_dims: Tuple):
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-                # TODO: Initialize bias to zero?
-                # TODO: See https://github.com/gnina/libmolgrid/blob/e6d5f36f1ae03f643ca69cdec1625ac52e653f88/test/test_torch_cnn.py#L48
-
     def forward(self, x: torch.Tensor):
         """
         Parameters
@@ -656,11 +650,6 @@ def __init__(
 
         self.features = nn.Sequential(features)
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-
     def forward(self, x):
         """
         Parameters
@@ -719,11 +708,6 @@ def __init__(
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-
     def forward(self, x):
         """
         Parameters
@@ -795,11 +779,6 @@ def __init__(
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-
     def forward(self, x):
         """
         Parameters
@@ -927,11 +906,6 @@ def __init__(self, input_dims: Tuple):
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-
     def forward(self, x: torch.Tensor):
         """
         Parameters
@@ -1066,11 +1040,6 @@ def __init__(self, input_dims: Tuple):
             )
         )
 
-        # Xavier initialization for convolutional and linear layers
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight.data)
-
     def forward(self, x: torch.Tensor):
         """
         Parameters
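For reference, the "uniform Xavier initialization" named in the new docstring (Glorot & Bengio, 2010) draws each weight from U(-a, a) with a = gain · sqrt(6 / (fan_in + fan_out)), and torch.nn.init.xavier_uniform_ uses gain = 1 by default. A minimal sketch of the bound, with a hypothetical helper name used purely for illustration:

    import math

    def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
        # Bound a of the uniform distribution U(-a, a) used by Xavier
        # (Glorot) uniform initialization; gain defaults to 1, as in PyTorch.
        return gain * math.sqrt(6.0 / (fan_in + fan_out))

    # Example: a linear layer with fan_in=256 and fan_out=128 draws its
    # weights from U(-0.125, 0.125).
    print(xavier_uniform_bound(fan_in=256, fan_out=128))  # 0.125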
3 changes: 2 additions & 1 deletion gnina/training.py
@@ -18,7 +18,7 @@
 from gnina import metrics, setup, utils
 from gnina.dataloaders import GriddedExamplesLoader
 from gnina.losses import AffinityLoss
-from gnina.models import models_dict
+from gnina.models import models_dict, weights_and_biases_init
 
 
 def options(args: Optional[List[str]] = None):
@@ -587,6 +587,7 @@ def training(args):
     # Create model
     # Select model based on architecture and affinity flag (pose vs affinity)
     model = models_dict[(args.model, affinity)](train_loader.dims).to(device)
+    model.apply(weights_and_biases_init)
 
     # TODO: Compile model into TorchScript
     # Requires model refactoring to avoid branching based on affinity
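With initialization gone from the model constructors, the call site is now responsible for it: nn.Module.apply walks the module tree recursively and calls the given function on every submodule, so the single model.apply(weights_and_biases_init) line above initializes all Conv3d and Linear layers and zeroes their biases, settling the old TODOs. A minimal sketch of that behaviour, using a hypothetical toy model rather than one of the gnina architectures:

    import torch
    from torch import nn

    from gnina.models import weights_and_biases_init

    # Hypothetical toy model mixing layer types the initializer targets
    # (Conv3d, Linear) with one it leaves untouched (BatchNorm3d).
    model = nn.Sequential(
        nn.Conv3d(in_channels=4, out_channels=8, kernel_size=3),
        nn.BatchNorm3d(8),
        nn.Linear(in_features=8, out_features=1),
    )

    # apply() recurses over all submodules and calls the function on each one.
    model.apply(weights_and_biases_init)

    assert torch.all(model[0].bias == 0.0)    # Conv3d bias zeroed
    assert torch.all(model[2].bias == 0.0)    # Linear bias zeroed
    assert torch.all(model[1].weight == 1.0)  # BatchNorm3d left at its default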
