Merge pull request #19 from RMeli/jit
Compile models and losses with TorchScript for training and inference
RMeli committed Dec 2, 2021
2 parents 39985f3 + d34c1e6 · commit f8aecb6
Showing 3 changed files with 14 additions and 9 deletions.
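
For context, `torch.jit.script` compiles an `nn.Module` ahead of time into a TorchScript program that can run without the Python interpreter, while keeping the module's usual interface (`parameters()`, `state_dict()`, `train()`/`eval()`). A minimal sketch of the pattern this PR applies everywhere; `TinyModel` is a hypothetical stand-in, not one of gnina's models:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for one of gnina's models (not the real code)."""

    def __init__(self, in_features: int = 8):
        super().__init__()
        self.linear = nn.Linear(in_features, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log_softmax(self.linear(x), dim=1)


model = TinyModel()

# torch.jit.script compiles the module; the result is a drop-in replacement
# that still exposes parameters(), state_dict(), train()/eval(), etc.
model = torch.jit.script(model)

out = model(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 2])
```
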
3 changes: 3 additions & 0 deletions gnina/inference.py
@@ -153,6 +153,9 @@ def inference(args):
     # Create model
     model = models.models_dict[(args.model, affinity)](test_loader.dims).to(device)
 
+    # Compile model with TorchScript
+    model = torch.jit.script(model)
+
     # Load checkpoint
     checkpoint = torch.load(args.checkpoint, map_location=device)
     Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
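
Scripting happens before `Checkpoint.load_objects`, which should be safe because a scripted module keeps the same parameter names and `state_dict()`/`load_state_dict()` interface as its eager counterpart (ignite's `Checkpoint` loads weights via `load_state_dict` under the hood). A minimal sketch of that property, using a plain `nn.Linear` rather than a gnina model:

```python
import torch
import torch.nn as nn

eager = nn.Linear(8, 2)
scripted = torch.jit.script(nn.Linear(8, 2))

# Scripting preserves parameter names ("weight", "bias"), so a state dict
# saved from the eager module loads straight into the scripted one.
scripted.load_state_dict(eager.state_dict())
assert torch.equal(scripted.weight, eager.weight)
```
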
2 changes: 1 addition & 1 deletion gnina/losses.py
@@ -90,7 +90,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
 
         if self.reduction == "mean":
             reduced_loss = torch.mean(loss)
-        elif self.reduction == "sum":
+        else:  # Assertion in init ensures that reduction is "sum"
            reduced_loss = torch.sum(loss)
 
         return self.scale * reduced_loss
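
The `elif` → `else` change is what makes this `forward` scriptable: the TorchScript compiler requires a variable to be assigned on every code path, and with a bare `if`/`elif` it cannot prove that `reduced_loss` is always defined. A minimal sketch of the pattern, mirroring (not reproducing) gnina's loss:

```python
import torch
from torch import Tensor


class ScaledLoss(torch.nn.Module):
    """Hypothetical loss mirroring the reduction pattern above."""

    def __init__(self, reduction: str = "mean", scale: float = 1.0):
        super().__init__()
        # Mirrors the assertion in __init__ referenced by the diff's comment
        assert reduction in ("mean", "sum")
        self.reduction = reduction
        self.scale = scale

    def forward(self, loss: Tensor) -> Tensor:
        if self.reduction == "mean":
            reduced_loss = torch.mean(loss)
        else:  # the assertion above guarantees reduction is "sum"
            reduced_loss = torch.sum(loss)
        # With `elif self.reduction == "sum":` and no `else`, torch.jit.script
        # would reject this: reduced_loss could be undefined on some path.
        return self.scale * reduced_loss


scripted = torch.jit.script(ScaledLoss("sum", scale=2.0))
print(scripted(torch.ones(3)))  # tensor(6.)
```
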
18 changes: 10 additions & 8 deletions gnina/training.py
@@ -589,8 +589,8 @@ def training(args):
     model = models_dict[(args.model, affinity)](train_loader.dims).to(device)
     model.apply(weights_and_biases_init)
 
-    # TODO: Compile model into TorchScript
-    # Requires model refactoring to avoid branching based on affinity
+    # Compile model into TorchScript
+    model = torch.jit.script(model)
 
     optimizer = optim.SGD(
         model.parameters(),
@@ -600,13 +600,15 @@ def training(args):
     )
 
     # Define loss functions
-    pose_loss = nn.NLLLoss()
+    pose_loss = torch.jit.script(nn.NLLLoss())
     affinity_loss = (
-        AffinityLoss(
-            delta=args.delta_affinity_loss,
-            penalty=args.penalty_affinity_loss,
-            pseudo_huber=args.pseudo_huber_affinity_loss,
-            scale=args.scale_affinity_loss,
+        torch.jit.script(
+            AffinityLoss(
+                delta=args.delta_affinity_loss,
+                penalty=args.penalty_affinity_loss,
+                pseudo_huber=args.pseudo_huber_affinity_loss,
+                scale=args.scale_affinity_loss,
+            )
         )
         if affinity
         else None
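
Built-in criteria such as `nn.NLLLoss` are ordinary `nn.Module`s, so they can be scripted the same way as the model and the custom `AffinityLoss`. A minimal usage sketch of a scripted pose loss; the shapes and values here are illustrative only:

```python
import torch
import torch.nn as nn

pose_loss = torch.jit.script(nn.NLLLoss())

# Log-probabilities for 4 samples over 2 classes, plus integer class targets.
log_probs = torch.log_softmax(torch.randn(4, 2, requires_grad=True), dim=1)
targets = torch.tensor([0, 1, 1, 0])

loss = pose_loss(log_probs, targets)
loss.backward()  # the scripted loss backpropagates like the eager one
```
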
