diff --git a/gnina/inference.py b/gnina/inference.py
index d543ac9..1ae6e07 100644
--- a/gnina/inference.py
+++ b/gnina/inference.py
@@ -153,6 +153,9 @@ def inference(args):
     # Create model
     model = models.models_dict[(args.model, affinity)](test_loader.dims).to(device)
 
+    # Compile model with TorchScript
+    model = torch.jit.script(model)
+
     # Load checkpoint
     checkpoint = torch.load(args.checkpoint, map_location=device)
     Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
diff --git a/gnina/losses.py b/gnina/losses.py
index 71943b4..8a67dfa 100644
--- a/gnina/losses.py
+++ b/gnina/losses.py
@@ -90,7 +90,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
 
         if self.reduction == "mean":
             reduced_loss = torch.mean(loss)
-        elif self.reduction == "sum":
+        else:  # Assertion in init ensures that reduction is "sum"
             reduced_loss = torch.sum(loss)
 
         return self.scale * reduced_loss
diff --git a/gnina/training.py b/gnina/training.py
index be4f32e..d22858b 100644
--- a/gnina/training.py
+++ b/gnina/training.py
@@ -589,8 +589,8 @@ def training(args):
     model = models_dict[(args.model, affinity)](train_loader.dims).to(device)
     model.apply(weights_and_biases_init)
 
-    # TODO: Compile model into TorchScript
-    # Requires model refactoring to avoid branching based on affinity
+    # Compile model into TorchScript
+    model = torch.jit.script(model)
 
     optimizer = optim.SGD(
         model.parameters(),
@@ -600,13 +600,15 @@ def training(args):
     )
 
     # Define loss functions
-    pose_loss = nn.NLLLoss()
+    pose_loss = torch.jit.script(nn.NLLLoss())
     affinity_loss = (
-        AffinityLoss(
-            delta=args.delta_affinity_loss,
-            penalty=args.penalty_affinity_loss,
-            pseudo_huber=args.pseudo_huber_affinity_loss,
-            scale=args.scale_affinity_loss,
+        torch.jit.script(
+            AffinityLoss(
+                delta=args.delta_affinity_loss,
+                penalty=args.penalty_affinity_loss,
+                pseudo_huber=args.pseudo_huber_affinity_loss,
+                scale=args.scale_affinity_loss,
+            )
         )
         if affinity
         else None
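
For reference, a minimal standalone sketch of the torch.jit.script pattern applied in this diff. TinyModel below is a hypothetical stand-in for the gnina models (which, per the removed TODO, first needed refactoring to avoid Python-level branching on affinity before they could be scripted); nothing here is part of the gnina codebase. Scripted modules remain drop-in callables, which is why the training and inference code can use them unchanged.

# Standalone sketch, not part of the diff above.
# TinyModel is a hypothetical stand-in for the real gnina models.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, in_features: int = 8, n_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(in_features, n_classes)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TorchScript needs statically typed forward methods without
        # data-dependent Python branching; plain tensor ops script cleanly.
        return self.logsoftmax(self.linear(x))


# Scripting returns a ScriptModule with the same call interface,
# mirroring model = torch.jit.script(model) in the diff.
model = torch.jit.script(TinyModel())

# Built-in loss modules are scriptable too, as with pose_loss above.
pose_loss = torch.jit.script(nn.NLLLoss())

x = torch.randn(4, 8)
target = torch.randint(0, 2, (4,))
loss = pose_loss(model(x), target)
print(loss.item())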