-
Notifications
You must be signed in to change notification settings - Fork 4
/
training.py
717 lines (620 loc) · 20 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
"""
PyTorch implementation of GNINA scoring function's Caffe training script.
"""
import argparse
import os
import sys
from typing import List, Optional
import molgrid
import numpy as np
import torch
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint
from torch import nn, optim
from gnina import metrics, setup, utils
from gnina.dataloaders import GriddedExamplesLoader
from gnina.losses import AffinityLoss
from gnina.models import models_dict, weights_and_biases_init
def options(args: Optional[List[str]] = None):
    """
    Define options and parse arguments.

    Parameters
    ----------
    args: Optional[List[str]]
        List of command line arguments. If :code:`None`, :code:`argparse` parses
        :code:`sys.argv[1:]`.

    Returns
    -------
    argparse.Namespace
        Parsed command line arguments
    """
    parser = argparse.ArgumentParser(
        description="GNINA scoring function",
    )

    # Data
    # TODO: Allow multiple train files?
    parser.add_argument("trainfile", type=str, help="Training file")
    parser.add_argument("--testfile", type=str, default=None, help="Test file")
    parser.add_argument(
        "-d",
        "--data_root",
        type=str,
        default="",
        help="Root folder for relative paths in train files",
    )
    parser.add_argument(
        "--balanced", action="store_true", help="Balanced sampling of receptors"
    )
    parser.add_argument(
        "--no_shuffle",
        action="store_false",
        help="Deactivate random shuffling of samples",
        dest="shuffle",  # Variable name (shuffle is False when --no_shuffle is used)
    )
    parser.add_argument(
        "--label_pos", type=int, default=0, help="Pose label position in training file"
    )
    parser.add_argument(
        "--affinity_pos",
        type=int,
        default=None,
        help="Affinity value position in training file",
    )
    parser.add_argument(
        "--stratify_receptor",
        action="store_true",
        help="Sample uniformly across receptors",
    )
    parser.add_argument(
        "--ligmolcache",
        type=str,
        default="",
        help=".molcache2 file for ligands",
    )
    parser.add_argument(
        "--recmolcache",
        type=str,
        default="",
        help=".molcache2 file for receptors",
    )
    parser.add_argument(
        "-o", "--out_dir", type=str, default=os.getcwd(), help="Output directory"
    )

    # Scoring function
    # models_dict is keyed by (model name, affinity flag) tuples, so every model name
    # appears once per flag: deduplicate (and sort) for a clean --help listing
    model_names = sorted({key[0] for key in models_dict})
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="default2017",
        help="Model name",
        choices=model_names,
    )
    parser.add_argument("--dimension", type=float, default=23.5, help="Grid dimension")
    parser.add_argument("--resolution", type=float, default=0.5, help="Grid resolution")
    # TODO: ligand type file and receptor type file (default: 28 types)

    # Learning
    parser.add_argument(
        "--base_lr", type=float, default=0.01, help="Base (initial) learning rate"
    )
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum")
    parser.add_argument(
        "--weight_decay", type=float, help="Weight decay", default=0.001
    )
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--no_random_rotation",
        action="store_false",
        help="Deactivate random rotation of samples",
        dest="random_rotation",
    )
    parser.add_argument(
        "--random_translation", type=float, default=6.0, help="Random translation"
    )
    parser.add_argument(
        "-i",
        "--iterations",
        type=int,
        default=250000,
        help="Number of iterations (epochs)",
    )
    parser.add_argument(
        "--iteration_scheme",
        type=str,
        default="small",
        help="molgrid iteration scheme",
        choices=setup._iteration_schemes.keys(),
    )
    # lr_dynamic, originally called --dynamic
    parser.add_argument(
        "--lr_dynamic",
        action="store_true",
        help="Adjust learning rate in response to training",
    )
    # lr_patience, originally called --step_when
    # Counted in scheduler steps, which only happen when the model is evaluated
    # (see --test_every), not in training epochs
    parser.add_argument(
        "--lr_patience",
        type=int,
        default=5,
        help="Number of evaluations without improvement before learning rate update",
    )
    # lr_reduce, originally called --step_reduce
    parser.add_argument(
        "--lr_reduce", type=float, default=0.1, help="Learning rate reduction factor"
    )
    # lr_min default value set to match --step_end_cnt default value (3 reductions)
    parser.add_argument(
        "--lr_min",
        type=float,
        default=0.01 * 0.1 ** 3,
        help="Minimum learning rate (lower bound for --lr_dynamic reductions)",
    )
    parser.add_argument(
        "--clip_gradients",
        type=float,
        default=10.0,
        help="Gradients threshold (for clipping)",
    )
    parser.add_argument(
        "--pseudo_huber_affinity_loss",
        action="store_true",
        help="Use pseudo-Huber loss for affinity loss",
    )
    parser.add_argument(
        "--delta_affinity_loss",
        type=float,
        default=4.0,
        help="Delta factor for affinity loss",
    )
    parser.add_argument(
        "--scale_affinity_loss",
        type=float,
        default=1.0,
        help="Scale factor for affinity loss",
    )
    parser.add_argument(
        "--penalty_affinity_loss",
        type=float,
        default=1.0,
        help="Penalty for affinity loss",
    )

    # Misc
    parser.add_argument(
        "-t", "--test_every", type=int, default=1000, help="Test interval"
    )
    parser.add_argument(
        "--checkpoint_every",
        type=int,
        default=100,
        help="Number of epochs per checkpoint",
    )
    parser.add_argument(
        "--num_checkpoints", type=int, default=1, help="Number of checkpoints to keep"
    )
    parser.add_argument("--progress_bar", action="store_true", help="Show progress bar")
    parser.add_argument("-g", "--gpu", type=str, default="cuda:0", help="Device name")
    # ROC AUC fails when there is only one class (i.e. all poses are good poses)
    # This happens when training with crystal structures only
    parser.add_argument(
        "--no_roc_auc",
        action="store_false",
        help="Disable ROC AUC (useful for crystal poses)",
        dest="roc_auc",
    )
    parser.add_argument("-s", "--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--silent", action="store_true", help="No console output")

    return parser.parse_args(args)
def _train_step_pose(
    trainer: Engine,
    batch,
    model: nn.Module,
    optimizer,
    pose_loss: nn.Module,
    clip_gradients: float,
) -> float:
    """
    Run a single optimisation step for pose prediction only.

    Parameters
    ----------
    trainer: Engine
        PyTorch Ignite engine driving the training (not used by the step itself)
    batch:
        Batch of data, unpacked as ``(grids, labels)``
    model:
        PyTorch model returning the pose prediction for the grids
    optimizer:
        PyTorch optimizer updating the model parameters
    pose_loss:
        Loss function for pose prediction
    clip_gradients:
        Maximum total gradient norm allowed before the parameter update

    Returns
    -------
    float
        Value of the pose loss for this batch

    Notes
    -----
    Gradients are clipped by norm and not by value.
    """
    model.train()
    optimizer.zero_grad()

    # The ExampleProvider already places the tensors on the right device
    inputs, targets = batch

    step_loss = pose_loss(model(inputs), targets)
    step_loss.backward()

    # TODO: Double check that gradient clipping by norm corresponds to the Caffe
    # implementation
    nn.utils.clip_grad_norm_(model.parameters(), clip_gradients)
    optimizer.step()

    return step_loss.item()
def _train_step_pose_and_affinity(
    trainer: Engine,
    batch,
    model: nn.Module,
    optimizer,
    pose_loss: nn.Module,
    affinity_loss: nn.Module,
    clip_gradients: float,
) -> float:
    """
    Run a single optimisation step for joint pose and affinity prediction.

    The optimised objective is the sum of the pose loss and the affinity loss.

    Parameters
    ----------
    trainer: Engine
        PyTorch Ignite engine driving the training (not used by the step itself)
    batch:
        Batch of data, unpacked as ``(grids, labels, affinities)``
    model:
        PyTorch model returning ``(pose prediction, affinity prediction)``
    optimizer:
        PyTorch optimizer updating the model parameters
    pose_loss:
        Loss function for pose prediction
    affinity_loss:
        Loss function for binding affinity prediction
    clip_gradients:
        Maximum total gradient norm allowed before the parameter update

    Returns
    -------
    float
        Value of the combined (pose + affinity) loss for this batch

    Notes
    -----
    Gradients are clipped by norm and not by value.
    """
    model.train()
    optimizer.zero_grad()

    # The ExampleProvider already places the tensors on the right device
    inputs, pose_targets, affinity_targets = batch
    pose_prediction, affinity_prediction = model(inputs)

    # Combined objective: pose term first, then affinity term
    pose_term = pose_loss(pose_prediction, pose_targets)
    affinity_term = affinity_loss(affinity_prediction, affinity_targets)
    total_loss = pose_term + affinity_term
    total_loss.backward()

    # TODO: Double check that gradient clipping by norm corresponds to the Caffe
    # implementation
    nn.utils.clip_grad_norm_(model.parameters(), clip_gradients)
    optimizer.step()

    return total_loss.item()
def _setup_trainer(
    model, optimizer, pose_loss, affinity_loss, clip_gradients: float
) -> Engine:
    """
    Setup training engine for binding pose prediction or binding pose and affinity
    prediction.

    Parameters
    ----------
    model:
        Model to train
    optimizer:
        Optimizer
    pose_loss:
        Loss function for pose prediction
    affinity_loss:
        Loss function for affinity prediction, or :code:`None` for pose prediction
        only
    clip_gradients:
        Gradient clipping threshold

    Returns
    -------
    ignite.Engine
        PyTorch Ignite engine performing one training step per batch

    Notes
    -----
    If :code:`affinity_loss` is not :code:`None`, the model returns both pose and
    affinity predictions, which requires a custom training step to evaluate the
    combined loss function. This training step is defined in
    :func:`_train_step_pose_and_affinity`. Otherwise the pose-only training step
    :func:`_train_step_pose` is used.
    """
    if affinity_loss is not None:
        # Pose prediction and binding affinity prediction (combined loss)
        def train_step(engine, batch):
            return _train_step_pose_and_affinity(
                engine,
                batch,
                model,
                optimizer,
                pose_loss=pose_loss,
                affinity_loss=affinity_loss,
                clip_gradients=clip_gradients,
            )

    else:
        # Pose prediction only
        def train_step(engine, batch):
            return _train_step_pose(
                engine,
                batch,
                model,
                optimizer,
                pose_loss=pose_loss,
                clip_gradients=clip_gradients,
            )

    return Engine(train_step)
def _evaluation_step_pose_and_affinity(evaluator: Engine, batch, model):
    """
    Evaluate model for binding pose and affinity prediction.

    Parameters
    ----------
    evaluator:
        PyTorch Ignite :code:`Engine` (unused; required by the engine's step
        signature)
    batch:
        Batch data, unpacked as :code:`(grids, labels, affinities)`
    model:
        Model returning :code:`(pose_log, affinities_pred)` for the grids

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary with the log class probabilities for pose prediction
        (:code:`"pose_log"`), the affinity predictions (:code:`"affinities_pred"`),
        the true pose labels (:code:`"labels"`) and the experimental binding
        affinities (:code:`"affinities"`)

    Notes
    -----
    The model returns the log softmax of the last linear layer for binding pose
    prediction (log class probabilities) and the raw output of the last linear layer
    for binding affinity predictions.

    The model is switched to evaluation mode and no gradients are tracked.
    """
    model.eval()

    with torch.no_grad():
        grids, labels, affinities = batch
        pose_log, affinities_pred = model(grids)

    return {
        "pose_log": pose_log,
        "affinities_pred": affinities_pred,
        "labels": labels,
        "affinities": affinities,
    }
def _evaluation_step_pose(evaluator: Engine, batch, model):
    """
    Evaluate model for binding pose prediction only.

    Parameters
    ----------
    evaluator:
        PyTorch Ignite :code:`Engine` (unused; required by the engine's step
        signature)
    batch:
        Batch data, unpacked as :code:`(grids, labels)`
    model:
        Model returning the pose prediction (:code:`pose_log`) for the grids

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary with the log class probabilities for pose prediction
        (:code:`"pose_log"`) and the true pose labels (:code:`"labels"`)

    Notes
    -----
    While not strictly necessary (the default PyTorch Ignite evaluator would work well
    in the case of pose-prediction only), this function is used to return a dictionary
    of the output with the same keys used in
    :func:`_evaluation_step_pose_and_affinity`. This allows to simplify the code of
    the learning rate scheduler function. This function also allows consistency in
    allowing the use of :func:`transforms.output_transform_select_pose` for both pose
    prediction only and binding pose prediction with binding affinity prediction.
    """
    model.eval()

    with torch.no_grad():
        grids, labels = batch
        pose_log = model(grids)

    return {
        "pose_log": pose_log,
        "labels": labels,
    }
def _setup_evaluator(model, metrics, affinity: bool = False) -> Engine:
    """
    Setup PyTorch Ignite :code:`Engine` for evaluation.

    Parameters
    ----------
    model:
        PyTorch model
    metrics:
        Evaluation metrics, as a mapping from metric name to PyTorch Ignite metric
    affinity: bool
        Flag for affinity prediction (in addition to pose prediction)

    Returns
    -------
    ignite.Engine
        PyTorch Ignite engine for evaluation, with all :code:`metrics` attached

    Notes
    -----
    A custom evaluation step is used in both cases:
    :func:`_evaluation_step_pose` for pose prediction only and
    :func:`_evaluation_step_pose_and_affinity` for pose and binding affinity
    prediction. Both steps return a dictionary, so each metric needs an
    :code:`output_transform` selecting the relevant entries.
    """
    # Pick the evaluation step matching the model outputs
    step = _evaluation_step_pose_and_affinity if affinity else _evaluation_step_pose

    evaluator = Engine(lambda engine, batch: step(engine, batch, model))

    # Attach every metric to the evaluator under its name
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    return evaluator
def training(args):
    """
    Main function for training GNINA scoring function.

    Parameters
    ----------
    args:
        Command line arguments (as returned by :func:`options`)

    Raises
    ------
    ValueError
        If the test grids do not have the same dimensions as the training grids

    Notes
    -----
    Training might start off slow because the :code:`molgrid.ExampleProvider` is caching
    the structures that are read from .gninatypes files. The training then speeds up
    considerably.

    When :code:`args.lr_dynamic` is set, the learning rate scheduler is stepped once
    per evaluation interval (every :code:`args.test_every` epochs) with the total
    (pose + affinity) loss on the training set, and the learning rate is reduced when
    that loss stops decreasing.
    """
    # Create necessary directories if not already present
    os.makedirs(args.out_dir, exist_ok=True)

    # The context manager guarantees the log file is closed even if training fails
    with open(os.path.join(args.out_dir, "training.log"), "w") as logfile:
        # Define output streams for logging
        if not args.silent:
            outstreams = [sys.stdout, logfile]
        else:
            outstreams = [logfile]

        # Print command line arguments
        for outstream in outstreams:
            utils.print_args(args, "--- GNINA TRAINING ---", stream=outstream)

        # Set random seed for reproducibility
        if args.seed is not None:
            molgrid.set_random_seed(args.seed)
            torch.manual_seed(args.seed)
            np.random.seed(args.seed)

        # Set device
        device = utils.set_device(args.gpu)

        # Create example providers
        train_example_provider = setup.setup_example_provider(
            args.trainfile, args, training=True
        )
        test_example_provider = None
        if args.testfile is not None:
            test_example_provider = setup.setup_example_provider(
                args.testfile, args, training=False
            )

        # Create grid maker (shared by the training and test loaders)
        grid_maker = setup.setup_grid_maker(args)

        train_loader = GriddedExamplesLoader(
            example_provider=train_example_provider,
            grid_maker=grid_maker,
            label_pos=args.label_pos,
            affinity_pos=args.affinity_pos,
            random_translation=args.random_translation,
            random_rotation=args.random_rotation,
            device=device,
        )

        test_loader = None
        if test_example_provider is not None:
            test_loader = GriddedExamplesLoader(
                example_provider=test_example_provider,
                grid_maker=grid_maker,
                label_pos=args.label_pos,
                affinity_pos=args.affinity_pos,
                random_translation=args.random_translation,
                random_rotation=args.random_rotation,
                device=device,
            )
            # Explicit check instead of assert, so it is not stripped by python -O
            if test_loader.dims != train_loader.dims:
                raise ValueError(
                    f"Test grid dimensions {test_loader.dims} differ from "
                    f"training grid dimensions {train_loader.dims}"
                )

        affinity: bool = args.affinity_pos is not None

        # Create model
        # Select model based on architecture and affinity flag (pose vs affinity)
        model = models_dict[(args.model, affinity)](train_loader.dims).to(device)
        model.apply(weights_and_biases_init)

        # Compile model into TorchScript
        model = torch.jit.script(model)

        optimizer = optim.SGD(
            model.parameters(),
            lr=args.base_lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

        # Define loss functions
        pose_loss = torch.jit.script(nn.NLLLoss())
        affinity_loss = (
            torch.jit.script(
                AffinityLoss(
                    delta=args.delta_affinity_loss,
                    penalty=args.penalty_affinity_loss,
                    pseudo_huber=args.pseudo_huber_affinity_loss,
                    scale=args.scale_affinity_loss,
                )
            )
            if affinity
            else None
        )

        trainer = _setup_trainer(
            model,
            optimizer,
            pose_loss=pose_loss,
            affinity_loss=affinity_loss,
            clip_gradients=args.clip_gradients,
        )

        allmetrics = metrics.setup_metrics(
            affinity, pose_loss, affinity_loss, args.roc_auc, device
        )
        evaluator = _setup_evaluator(model, allmetrics, affinity=affinity)

        # Optional learning rate scheduler
        torch_scheduler = None
        if args.lr_dynamic:
            # The monitored quantity is a loss, which must decrease: "min" mode.
            # ("max" would reduce the learning rate while the loss is improving.)
            # `verbose` is omitted: it defaults to off and is deprecated in recent
            # PyTorch releases.
            torch_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=args.lr_reduce,
                patience=args.lr_patience,
                min_lr=args.lr_min,
            )

        @trainer.on(Events.EPOCH_COMPLETED(every=args.test_every))
        def log_training_results(trainer):
            evaluator.run(train_loader)
            for outstream in outstreams:
                utils.log_print(
                    evaluator.state.metrics,
                    title="Train Results",
                    epoch=trainer.state.epoch,
                    stream=outstream,
                )

            # TODO: Save lr history
            if torch_scheduler is not None:
                # Step once per evaluation interval, on training metrics only.
                # The evaluator is shared with the test set, so hooking the
                # scheduler to evaluator events would step it twice per interval
                # on a mix of training and test losses.
                train_metrics = evaluator.state.metrics
                total_loss = train_metrics["Loss (pose)"] + train_metrics.get(
                    "Loss (affinity)", 0.0  # No affinity loss for pose-only models
                )
                torch_scheduler.step(total_loss)

                # optim.SGD above is created with a single parameter group
                assert len(optimizer.param_groups) == 1
                for outstream in outstreams:
                    print(
                        f" Learning rate: {optimizer.param_groups[0]['lr']}",
                        file=outstream,
                    )

        if test_loader is not None:

            @trainer.on(Events.EPOCH_COMPLETED(every=args.test_every))
            def log_test_results(trainer):
                evaluator.run(test_loader)
                for outstream in outstreams:
                    utils.log_print(
                        evaluator.state.metrics,
                        title="Test Results",
                        epoch=trainer.state.epoch,
                        stream=outstream,
                    )

        # TODO: Save input parameters as well
        # TODO: Save best models (lower loss)
        to_save = {"model": model, "optimizer": optimizer}

        # Requires no checkpoint in the output directory
        # Since checkpoints are not automatically removed when restarting, it would be
        # dangerous to run without requiring the directory to have no previous
        # checkpoints
        checkpoint = Checkpoint(
            to_save,
            args.out_dir,
            n_saved=args.num_checkpoints,
            global_step_transform=lambda *_: trainer.state.epoch,
        )
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=args.checkpoint_every), checkpoint
        )

        if args.progress_bar:
            pbar = ProgressBar()
            pbar.attach(trainer)

        trainer.run(train_loader, max_epochs=args.iterations)
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the training
    training(options())