Skip to content

Commit

Permalink
BioLib update
Browse files Browse the repository at this point in the history
  • Loading branch information
FSGade committed Apr 27, 2023
1 parent e26335a commit 1ca85d6
Showing 1 changed file with 132 additions and 18 deletions.
150 changes: 132 additions & 18 deletions src/predict_biolib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import glob
import os
import pickle
import re
import subprocess
import tempfile
Expand All @@ -40,10 +41,31 @@
from predict_webserver import (get_basename_no_ext,
get_directory_basename_dict,
predict_using_models, read_list_file,
set_struc_res_bfactor, true_if_zip)
set_struc_res_bfactor, true_if_zip,
load_gam_model, normalize_scores)


def save_clean_pdb_single_chains(
pdb_path, pdb_name, bscore, outdir, save_full_complex=False
):
"""
Function to save cleaned PDB file(s) with specified B-factor score.
Parameters
----------
pdb_path : str
The path to the PDB file.
pdb_name : str
The name of the PDB file.
bscore : int or None
The B-factor score to be assigned to the residues/atoms; None leaves the B-factors unchanged. Use 100 for solved structures and None for AF2 structures.
outdir : str
The directory to save the output file(s).
save_full_complex : bool, optional
If True, the function will save the whole complex as a single PDB file.
Otherwise, individual chains are saved as separate PDB files. Default is False.
"""


def save_clean_pdb_single_chains(pdb_path, pdb_name, bscore, outdir):
class Clean_Chain(Select):
def __init__(self, score, chain=None):
self.bscore = bscore
Expand Down Expand Up @@ -108,18 +130,39 @@ def accept_atom(self, atom):
else:
break

chains = structure.get_chains()

for chain in chains:
pdb_out = f"{outdir}/{pdb_name}_{chain.get_id()}.pdb"
# Save whole complex, cleaned
if save_full_complex:
pdb_out = f"{outdir}/{pdb_name}.pdb"
io_w_no_h = PDBIO()
io_w_no_h.set_structure(structure)
with open(pdb_out, "w") as f:
print(*header, sep="\n", file=f)
io_w_no_h.save(f, Clean_Chain(bscore, chain))
io_w_no_h.save(f, Clean_Chain(bscore, chain=None))


def predict_and_save(models, dataset, pdb_dir, out_dir, verbose: int = 0) -> None:
# Save individual chains, cleaned (default)
else:
chains = structure.get_chains()

for chain in chains:
pdb_out = f"{outdir}/{pdb_name}_{chain.get_id()}.pdb"
io_w_no_h = PDBIO()
io_w_no_h.set_structure(structure)
with open(pdb_out, "w") as f:
print(*header, sep="\n", file=f)
io_w_no_h.save(f, Clean_Chain(bscore, chain=chain))


def predict_and_save(
models,
dataset,
pdb_dir,
out_dir,
gam_len_to_mean=False,
gam_surface_to_std=False,
calibrated_score_epi_threshold=0.90,
no_calibrated_normalization=False,
verbose: int = 0,
) -> None:
"""Predicts and saves CSV/PDBs with DiscoTope-3.0 scores"""

log.debug(f"Predicting PDBs ...")
Expand All @@ -142,7 +185,9 @@ def predict_and_save(models, dataset, pdb_dir, out_dir, verbose: int = 0) -> Non
if dataset[i]["X_arr"] is not False
]
)
df_all.insert(3, "DiscoTope-3.0_score", y_all)
df_all.insert(4, "DiscoTope-3.0_score", y_all)
df_all.insert(5, "calibrated_score", np.nan)
df_all.insert(6, "epitope", np.nan)

# Round numerical columns to 5 digits for nicer CSV output
num_cols = ["DiscoTope-3.0_score", "rsa"]
Expand Down Expand Up @@ -190,14 +235,30 @@ def predict_and_save(models, dataset, pdb_dir, out_dir, verbose: int = 0) -> Non
df = df_all.iloc[start:end]
start = end

# Normalize for length and surface area with Calibrated-scores
calibrated_scores = normalize_scores(df, gam_len_to_mean, gam_surface_to_std)

# Epitopes are called using a fixed threshold on the Calibrated-score; the default (0.90) is the median epitope Calibrated-score
# NB: median Calibrated-scores are 0.00 for all residues, 0.50 for exposed residues and 0.90 for exposed epitope residues
df["epitope"] = calibrated_scores >= calibrated_score_epi_threshold

# Set Calibrated-scores to string for nicer CSV output
df["calibrated_score"] = pd.Series(calibrated_scores).apply(lambda x: "{:.5f}".format(x))

# Save CSV
outfile = f"{out_dir}/{_pdb}_discotope3.csv"
df.to_csv(outfile, index=False)

# Save PDB, after adding prediction scores
struc_pred = set_struc_res_bfactor(
struc, df["DiscoTope-3.0_score"].values.astype(float) * 100
)
# Save PDB with or without Calibrated-normalized scores
if no_calibrated_normalization:
struc_pred = set_struc_res_bfactor(
struc, df["DiscoTope-3.0_score"].values.astype(float) * 100
)
else:
struc_pred = set_struc_res_bfactor(
struc, df["calibrated_score"].values.astype(float) * 100
)

outfile = f"{out_dir}/{_pdb}_discotope3.pdb"
strucio.save_structure(outfile, struc_pred)

Expand Down Expand Up @@ -338,7 +399,11 @@ def fetch_pdbs_extract_single_chains(pdb_list, out_dir) -> None:
sys.exit(1)

save_clean_pdb_single_chains(
pdb_path=f"{out_dir}/temp", pdb_name=prot_id, bscore=bscore, outdir=out_dir
pdb_path=f"{out_dir}/temp",
pdb_name=prot_id,
bscore=bscore,
outdir=out_dir,
save_full_complex=args.multichain_mode,
)


Expand Down Expand Up @@ -404,6 +469,20 @@ def is_valid_path(parser, arg):
type=lambda x: is_valid_path(p, x),
)

p.add_argument(
"--calibrated_score_epi_threshold",
type=float,
help="Calibrated-score threshold for epitopes [low 0.40, moderate (0.90), higher 1.50]",
default=0.90,
)

p.add_argument(
"--no_calibrated_normalization",
action="store_true",
default=False,
help="Skip Calibrated-normalization of PDBs",
)

p.add_argument(
"--check_existing_embeddings",
default=False,
Expand All @@ -424,6 +503,13 @@ def is_valid_path(parser, arg):
default=1000,
)

p.add_argument(
"--multichain_mode",
action="store_true",
default=False,
help="Predicts entire complexes, unsupported and untested",
)

p.add_argument(
"--save_embeddings",
default=False,
Expand Down Expand Up @@ -550,6 +636,12 @@ def true_if_list(infile):
def main(args):
"""Main function"""

# Log if multichain mode is set
if args.multichain_mode:
log.info(f"Multi-chain mode set, will predict PDBs as complexes")
else:
log.info(f"Single-chain mode set, will predict PDBs as single chains")

# Error messages if invalid input
is_list_file = check_valid_input(args)

Expand Down Expand Up @@ -596,7 +688,13 @@ def main(args):
)
for f in pdb_list:
pdb_name = get_basename_no_ext(f)
save_clean_pdb_single_chains(f, pdb_name, bscore, input_chains_dir)
save_clean_pdb_single_chains(
f,
pdb_name,
bscore,
input_chains_dir,
save_full_complex=args.multichain_mode,
)

# 3. If single PDB, copy to tempdir
else:
Expand All @@ -606,7 +704,13 @@ def main(args):
log.info(
f"Single PDB file input ({pdb_name}), extracting single chains to {input_chains_dir}"
)
save_clean_pdb_single_chains(f, pdb_name, bscore, input_chains_dir)
save_clean_pdb_single_chains(
f,
pdb_name,
bscore,
input_chains_dir,
save_full_complex=args.multichain_mode,
)

# Summary statistics
chain_list = glob.glob(f"{input_chains_dir}/*.pdb")
Expand All @@ -630,12 +734,22 @@ def main(args):
# Load pre-trained XGBoost models
models = load_models(args.models_dir, num_models=100)

# Load GAMs to normalize scores by length and surface area
gam_len_to_mean = load_gam_model(f"{args.models_dir}/gam_len_to_mean.pkl")
gam_surface_to_std = load_gam_model(
f"{args.models_dir}/gam_surface_to_std.pkl"
)

# Predict and save
predict_and_save(
models,
dataset,
input_chains_dir,
out_dir=out_dir,
gam_len_to_mean=gam_len_to_mean,
gam_surface_to_std=gam_surface_to_std,
calibrated_score_epi_threshold=args.calibrated_score_epi_threshold,
no_calibrated_normalization=args.no_calibrated_normalization,
verbose=args.verbose,
)

Expand Down

0 comments on commit 1ca85d6

Please sign in to comment.