Commit

Merge branch 'duplex_file_handle' into 'master'
Duplex mod calling

See merge request algorithm/remora!276
marcus1487 committed Nov 15, 2023
2 parents 6eec805 + 4402ab4 commit 5a8e020
Showing 2 changed files with 75 additions and 71 deletions.
134 changes: 69 additions & 65 deletions src/remora/inference.py
@@ -3,8 +3,8 @@
from copy import copy
import array as pyarray
from pathlib import Path
from itertools import chain
from threading import Thread
from itertools import chain, islice
from collections import defaultdict

import pysam
@@ -708,20 +708,30 @@ def check_simplex_alignments(*, simplex_index, duplex_index, pairs):
all_paired_read_ids = set(chain(*pairs))
simplex_read_ids = set(simplex_index.read_ids)
duplex_read_ids = set(duplex_index.read_ids)
count_paired_simplex_alignments = sum(
[1 for read_id in all_paired_read_ids if read_id in simplex_read_ids]
num_paired_simplex_alignments = len(
all_paired_read_ids.intersection(simplex_read_ids)
)
if count_paired_simplex_alignments == 0:
LOGGER.debug(
f"Found {num_paired_simplex_alignments} simplex simplex alignments in "
f"a pair out of {len(simplex_read_ids)} total simplex reads."
)
if num_paired_simplex_alignments == 0:
raise ValueError("zero simplex alignments found")

# valid meaning we have all the parts to perform inference
valid_read_pairs = [
(t, c)
for t, c in pairs
if all(read_id in simplex_read_ids for read_id in (t, c))
if t in simplex_read_ids
and c in simplex_read_ids
and t in duplex_read_ids
]
return valid_read_pairs, len(valid_read_pairs)
num_valid_read_pairs = len(valid_read_pairs)
LOGGER.debug(
f"Found {num_valid_read_pairs} valid reads out of {len(pairs)} "
"total pairs"
)
return valid_read_pairs, num_valid_read_pairs


def prep_duplex_read_builder(simplex_index, pod5_path):
@@ -743,6 +753,40 @@ def iter_duplexed_io_reads(read_id_pair, builder):
return builder.make_read_pair(read_id_pair)


def make_duplex_reads(read_pair_result, duplex_index):
read_pair, err = read_pair_result
if err is not None or read_pair is None:
return read_pair, err
template, complement = read_pair
if template.read_id not in duplex_index:
return read_pair, "duplex BAM record not found for read_id"
for bam_record in duplex_index.get_alignments(template.read_id):
duplex_read = DuplexRead.from_reads_and_alignment(
template_read=template,
complement_read=complement,
duplex_alignment=bam_record,
)
# TODO do we want to return all the duplex mappings?
return duplex_read, None


def add_mod_mappings_to_alignment(duplex_read_result, caller):
duplex_read, err = duplex_read_result
if err is not None:
return None, err
mod_tags = caller.call_duplex_read_mods(duplex_read)
mod_tags = list(mod_tags)
record = duplex_read.duplex_alignment
record = copy(record)
record["tags"] = [
tag
for tag in record["tags"]
if not (tag.startswith("MM") or tag.startswith("ML"))
]
record["tags"].extend(mod_tags)
return record, None


def infer_duplex(
*,
simplex_pod5_path: str,
@@ -759,13 +803,14 @@ def infer_duplex(
skip_non_primary=True,
duplex_deliminator=";",
):
LOGGER.info("Indexing Duplex BAM")
duplex_bam_index = ReadIndexedBam(
duplex_bam_path,
skip_non_primary=skip_non_primary,
req_tags=set(),
read_id_converter=lambda k: k.split(duplex_deliminator)[0],
)

LOGGER.info("Indexing Simplex BAM")
simplex_bam_index = ReadIndexedBam(
simplex_bam_path, skip_non_primary=True, req_tags={"mv"}
)
@@ -776,84 +821,40 @@ def add_mod_mappings_to_alignment(duplex_read_result, caller):
duplex_index=duplex_bam_index,
pairs=pairs,
)
if num_reads is not None:
num_reads = min(num_valid_reads, num_reads)

# source of pipeline, template, complement read ID pairs
# produces template, complement read id tuples
def iter_pairs(pairs, num_reads):
for pair_num, pair in enumerate(pairs):
if num_reads is not None and pair_num >= num_reads:
return
yield pair

read_id_pairs = BackgroundIter(
iter_pairs,
kwargs=dict(pairs=valid_pairs, num_reads=num_reads),
use_process=False,
num_reads = (
num_valid_reads
if num_reads is None
else min(num_valid_reads, num_reads)
)

# consumes: tuple of template, complement read Ids
# prep: open resources for Pod5 and simplex BAM
# produces: (io.Read, io.Read), str
io_read_pairs_results = MultitaskMap(
iter_duplexed_io_reads,
read_id_pairs,
islice(valid_pairs, num_reads),
prep_func=prep_duplex_read_builder,
args=(simplex_bam_index, simplex_pod5_path),
name="BuildDuplexedIoReads",
q_maxsize=100,
num_workers=num_extract_alignment_threads,
use_process=False,
use_mp_queue=False,
use_process=True,
use_mp_queue=True,
)

def make_duplex_reads(read_pair_result, duplex_index, bam_file_handle):
read_pair, err = read_pair_result
if err is not None or read_pair is None:
return None, err
template, complement = read_pair
if template.read_id not in duplex_index:
return None, "duplex BAM record not found for read_id"
for pointer in duplex_index[template.read_id]:
bam_file_handle.seek(pointer)
bam_record = next(bam_file_handle)
duplex_read = DuplexRead.from_reads_and_alignment(
template_read=template,
complement_read=complement,
duplex_alignment=bam_record,
)

return duplex_read, None

# consumes: tuple of io.Reads (template, complement)
# produces: (DuplexRead, str), for inference by the model
duplex_aln = pysam.AlignmentFile(duplex_bam_path, "rb", check_sq=False)
duplex_reads = MultitaskMap(
make_duplex_reads,
io_read_pairs_results,
num_workers=num_duplex_prep_workers,
args=(duplex_bam_index, duplex_aln),
args=(duplex_bam_index,),
name="MakeDuplexReads",
use_process=False,
use_mp_queue=False,
q_maxsize=100,
use_process=True,
use_mp_queue=True,
)

def add_mod_mappings_to_alignment(duplex_read_result, caller):
duplex_read, err = duplex_read_result
if err is not None:
return None, err
mod_tags = caller.call_duplex_read_mods(duplex_read)
mod_tags = list(mod_tags)
record = duplex_read.duplex_alignment
record = copy(record)
record["tags"] = [
tag
for tag in record["tags"]
if not (tag.startswith("MM") or tag.startswith("ML"))
]
record["tags"].extend(mod_tags)
return record, None

duplex_caller = DuplexReadModCaller(model, model_metadata)

# consumes: Result[DuplexReads, str]
@@ -865,6 +866,7 @@ def add_mod_mappings_to_alignment(duplex_read_result, caller):
num_workers=num_infer_threads,
args=(duplex_caller,),
name="InferMods",
q_maxsize=100,
use_process=False,
use_mp_queue=False,
)
@@ -875,8 +877,11 @@ def add_mod_mappings_to_alignment(duplex_read_result, caller):
with pysam.AlignmentFile(out_bam, "wb", template=in_bam) as out_bam:
for mod_read_mapping, err in tqdm(
alignment_records_with_mod_tags,
dynamic_ncols=True,
smoothing=0,
total=num_reads,
dynamic_ncols=True,
unit=" Duplex Reads",
desc="Inferring duplex mods",
disable=os.environ.get("LOG_SAFE", False),
):
if err is not None:
@@ -888,7 +893,6 @@ def add_mod_mappings_to_alignment(duplex_read_result, caller):
)
)
pysam.set_verbosity(pysam_save)
duplex_aln.close()

if len(errs) > 0:
err_types = sorted([(num, err) for err, num in errs.items()])[::-1]
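
The staged functions above hand (value, error) tuples from one pipeline stage to the next instead of raising, so a failure on a single read pair is carried through to the output loop and tallied there. A minimal, self-contained sketch of that convention (the stage names and inputs here are hypothetical, not Remora's):

def stage_one(item):
    # hypothetical first stage: reject bad input, otherwise produce a value
    if item < 0:
        return None, "negative input"
    return item * 2, None

def stage_two(result):
    value, err = result
    if err is not None:
        # short-circuit: pass the upstream error through untouched
        return None, err
    return value + 1, None

for item in (3, -1, 5):
    out, err = stage_two(stage_one(item))
    if err is not None:
        print(f"skipped one item: {err}")
    else:
        print(out)
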
12 changes: 6 additions & 6 deletions src/remora/io.py
@@ -302,10 +302,10 @@ def get_alignments(self, read_id):
self.bam_fh.seek(read_ptr)
try:
bam_read = next(self.bam_fh)
except OSError:
except OSError as e:
LOGGER.debug(
f"Could not extract {read_id} from {self.bam_path} "
f"at {read_ptr}"
f"at {read_ptr}\nFULL_ERROR: {e}"
)
raise RemoraError(
"Could not extract BAM read. Ensure BAM file object was "
@@ -1900,11 +1900,14 @@ def add_alignment(
except KeyError:
self.num_trimmed = 0

self.seq = alignment_record.query_sequence
if alignment_record.is_reverse:
self.seq = util.revcomp(self.seq)
try:
self.query_to_signal, self.mv_table, self.stride = parse_move_tag(
tags["mv"],
sig_len=self.dacs.size,
seq_len=len(alignment_record.query_sequence),
seq_len=len(self.seq),
reverse_signal=reverse_signal,
)
except KeyError:
@@ -1922,9 +1925,6 @@
) - self.shift_dacs_to_pa
self.scale_dacs_to_norm = self.scale_pa_to_norm / self.scale_dacs_to_pa

self.seq = alignment_record.query_sequence
if alignment_record.is_reverse:
self.seq = util.revcomp(self.seq)
if not parse_ref_align or alignment_record.is_unmapped:
return

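
The io.py change above is the counterpart of the inference.py refactor: make_duplex_reads now fetches duplex records through duplex_index.get_alignments(read_id) instead of seeking on a shared pysam.AlignmentFile handle, which lets the MultitaskMap stages run with use_process=True. A rough sketch of that lookup pattern, assuming the index maps read IDs to offsets from pysam's tell() and that the file handle is opened lazily inside each worker (class and attribute names are simplified, not the actual Remora implementation):

import pysam

class IndexedBamSketch:
    def __init__(self, bam_path, index):
        # index: read_id -> list of file offsets recorded with AlignmentFile.tell()
        self.bam_path = bam_path
        self.index = index
        self.bam_fh = None

    def get_alignments(self, read_id):
        if self.bam_fh is None:
            # opened lazily so the handle is created inside the worker process
            # rather than being shared (or pickled) across processes
            self.bam_fh = pysam.AlignmentFile(
                self.bam_path, mode="rb", check_sq=False
            )
        for read_ptr in self.index.get(read_id, []):
            self.bam_fh.seek(read_ptr)
            yield next(self.bam_fh)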
