#!/usr/bin/env python
"""Post-processing for methylation-based phasing results.

Reads the per-chromosome block-relationship CSV files, then writes a VCF in
which flipped SNP-phased blocks have their genotypes swapped, and one
haplotagged BAM file per chromosome.
"""
import argparse
import ast
import os
import re
import sys
import time
from itertools import groupby
from multiprocessing import Pool
from operator import itemgetter
from pathlib import Path

import pandas as pd
import pysam
import seaborn as sns
from pysam import VariantFile
from tqdm.auto import tqdm

# Value stored in ``myth_phasing_relationship`` when two neighbouring blocks
# cannot be related to each other.
_UNDECIDED = "cannot decide"


def parse_arg(argv):
    """Parse the command-line arguments.

    Args:
        argv: argument list without the program name (``sys.argv[1:]``).

    Returns:
        The populated ``argparse.Namespace``. Prints the help text and exits
        with status 1 when ``argv`` is empty.
    """
    parser = argparse.ArgumentParser(
        description="methphaser: phase reads based on methlytion informaiton"
    )
    required_args = parser.add_argument_group("Required arguments")
    required_args.add_argument(
        "-ib",
        "--input_bam_file",
        type=str,
        help="input SNP-phased bam file",
        required=True,
        metavar="",
    )
    required_args.add_argument(
        "-if",
        "--meth_phasing_input_folder",
        type=str,
        help="meth phasing input folder",
        required=True,
        metavar="",
    )
    required_args.add_argument(
        "-ov",
        "--output_vcf",
        type=str,
        help="output VCF file location",
        required=True,
        metavar="",
    )
    required_args.add_argument(
        "-ob",
        "--output_bam",
        type=str,
        help="output BAM file (without .bam suffix)",
        required=True,
        metavar="",
    )
    required_args.add_argument(
        "-vc",
        "--vcf_called",
        type=str,
        help="SNP-phased VCF file",
        required=True,
        metavar="",
    )
    required_args.add_argument(
        "-t",
        "--threads",
        type=int,
        help="threads, default 1",
        default=1,
        metavar="",
    )
    parser.add_argument(
        "-vt",
        "--vcf_truth",
        type=str,
        help="truth VCF provided by GIAB",
        metavar="",
    )
    parser.add_argument(
        "-hs",
        "--high_success_rate_param",
        help="Enable high success rate parameter",
        action="store_true",
    )
    # Both thresholds are compared against numeric pandas columns, so they
    # must be parsed as numbers; without ``type`` argparse hands over ``str``
    # and the comparison in ``result_filtering`` raises ``TypeError``.
    parser.add_argument(
        "-mc",
        "--minimum_coverage",
        type=float,
        help="Minimum read number to assign blocks' relationship. default: 0. "
        "Recommanded setting: mc = coverage/(autosome-wide block number/1000), "
        "for details please see the paper",
        default=0,
        metavar="",
    )
    parser.add_argument(
        "-vd",
        "--voting_difference",
        type=float,
        help="minimum voting difference for relationship assignment, default=0.5",
        default=0.5,
        metavar="",
    )
    if not argv:
        parser.print_help(sys.stderr)
        sys.exit(1)
    return parser.parse_args(argv)


def apply_flip_counter(relationship_to_block, flip_counter):
    """Propagate the flip state across one block relationship.

    Returns ``flip_counter`` unchanged when the relationship is ``"same"``
    and its negation for any other value.
    """
    if relationship_to_block == "same":
        return flip_counter
    return -1 * flip_counter


def result_filtering(comparison_df, min_reads_num, min_variance):
    """Mark weakly supported relationships as ``"cannot decide"``.

    A row is overwritten when its total vote count
    (``same_hap_num + diff_hap_num``) is ``<= min_reads_num`` or when the
    relative vote margin ``|same - diff| / (same + diff)`` is
    ``<= min_variance``.

    Args:
        comparison_df: relationship table; modified in place.
        min_reads_num: vote-count threshold.
        min_variance: relative vote-margin threshold.

    Returns:
        The same ``comparison_df`` object.
    """
    total_votes = comparison_df["same_hap_num"] + comparison_df["diff_hap_num"]
    # 0 / 0 yields NaN, which never satisfies ``<=``; such rows are only
    # caught by the vote-count test.
    vote_margin = (
        abs(comparison_df["same_hap_num"] - comparison_df["diff_hap_num"])
        / total_votes
    )
    undecidable = (total_votes <= min_reads_num) | (vote_margin <= min_variance)
    comparison_df.loc[undecidable, "myth_phasing_relationship"] = _UNDECIDED
    return comparison_df


def get_altered_block_start_loc(final_blocks, original_block_start_loc):
    """Return the start of the final block containing a position.

    Args:
        final_blocks: iterable of ``(start, end)`` pairs.
        original_block_start_loc: position to look up (inclusive bounds).

    Returns:
        ``str(start)`` of the first matching block, or the integer ``-1``
        when no block contains the position.
    """
    for block in final_blocks:
        if block[0] <= original_block_start_loc <= block[1]:
            return str(block[0])
    return -1


def _parse_block(block_repr):
    """Convert the textual ``(start, end)`` block stored in the CSV to a tuple.

    ``ast.literal_eval`` only accepts Python literals, so text read from the
    CSV files can never execute code the way ``eval`` would.
    """
    return ast.literal_eval(block_repr)


def get_all_final_blocks_dict(block_relationship_df):
    """Compute the phase blocks that result from merging related blocks.

    Consecutive rows with a decided relationship are merged into one block
    spanning from the first row's ``snp_phased_block_1`` start to the last
    row's ``snp_phased_block_2`` end. For every undecided row its
    ``snp_phased_block_2`` is added as is, and for an undecided row with index
    label 0 its ``snp_phased_block_1`` as well.

    Args:
        block_relationship_df: dict mapping chromosome name to its
            relationship table.

    Returns:
        Dict mapping chromosome name to a list of ``(start, end)`` tuples,
        merged blocks first.
    """
    final_block_dict = {}
    for current_chr, chrom_df in block_relationship_df.items():
        decided_df = chrom_df[chrom_df.myth_phasing_relationship != _UNDECIDED]
        decided_index = list(decided_df.index)
        decided_lookup = set(decided_index)

        current_block_list = []
        # ``position - label`` is constant within a run of consecutive labels,
        # so groupby yields one group per uninterrupted run of decided rows.
        for _, run in groupby(enumerate(decided_index), lambda ix: ix[0] - ix[1]):
            run_labels = list(map(itemgetter(1), run))
            start = _parse_block(decided_df.loc[run_labels[0]].snp_phased_block_1)[0]
            end = _parse_block(decided_df.loc[run_labels[-1]].snp_phased_block_2)[1]
            current_block_list.append((start, end))

        for label in chrom_df.index:
            if label in decided_lookup:
                continue
            row = chrom_df.loc[label]
            if label == 0:
                first_block = _parse_block(row.snp_phased_block_1)
                current_block_list.append((first_block[0], first_block[1]))
            second_block = _parse_block(row.snp_phased_block_2)
            current_block_list.append((second_block[0], second_block[1]))

        final_block_dict[current_chr] = current_block_list
    return final_block_dict


def _chain_blocks(chrom_df):
    """Walk one chromosome's relationship table and chain related blocks.

    Returns:
        ``(final_block_list, flipping_list)`` where ``final_block_list`` holds
        one ``(start, end)`` per chain and ``flipping_list`` holds
        ``(block_text, flag)`` pairs with ``flag`` 1 (keep) or -1 (flip).
    """
    relationships = list(chrom_df.myth_phasing_relationship)
    first_blocks = list(chrom_df.snp_phased_block_1)
    second_blocks = list(chrom_df.snp_phased_block_2)
    row_count = len(relationships)

    final_block_list = []
    flipping_list = []
    row = 0
    # NOTE(review): a chain is never started on the last row (bound is
    # ``row_count - 1``), so a decided last row that follows an undecided one
    # is ignored. Kept as in the original; confirm whether this is intended.
    while row < row_count - 1:
        if relationships[row] == _UNDECIDED:
            row += 1
            continue
        chain_start = _parse_block(first_blocks[row])[0]
        # The first block of a chain is the reference and is never flipped.
        flipping_list.append((first_blocks[row], 1))
        flip_counter = apply_flip_counter(relationships[row], 1)
        row += 1
        while row < row_count and relationships[row] != _UNDECIDED:
            flipping_list.append((first_blocks[row], flip_counter))
            flip_counter = apply_flip_counter(relationships[row], flip_counter)
            row += 1
        # The chain ends with the second block of its last decided row.
        flipping_list.append((second_blocks[row - 1], flip_counter))
        chain_end = _parse_block(second_blocks[row - 1])[1]
        final_block_list.append((chain_start, chain_end))
    return final_block_list, flipping_list


def _retag_phase_set(rec, final_blocks):
    """Point the PS value of a VCF line at the start of its final block.

    Only the PS entry of the last (sample) column is rewritten. The line is
    returned unchanged when it has no PS key, when the PS value is missing or
    not an integer, or when no final block contains the old PS position.
    """
    fields = rec.rstrip("\n").split("\t")
    format_keys = fields[-2].split(":")
    if "PS" not in format_keys:
        return rec
    sample_values = fields[-1].split(":")
    ps_index = format_keys.index("PS")
    if ps_index >= len(sample_values):
        return rec
    try:
        old_phase_set = int(sample_values[ps_index])
    except ValueError:
        return rec
    new_phase_set = get_altered_block_start_loc(final_blocks, old_phase_set)
    if new_phase_set == -1:
        return rec
    sample_values[ps_index] = new_phase_set
    fields[-1] = ":".join(sample_values)
    return "\t".join(fields) + "\n"


def _flip_record(rec, flip_flag, final_blocks):
    """Return the VCF line to write for a record inside a chained block.

    Lines of unflipped blocks (``flip_flag == 1``) are returned untouched.
    Otherwise every ``1|0`` is swapped to ``0|1`` (or, failing that, every
    ``0|1`` to ``1|0``) and the PS value is re-pointed; lines containing
    neither genotype are returned untouched.
    """
    if flip_flag == 1:
        return rec
    # NOTE(review): PS is only rewritten for flipped records, as in the
    # original; records of unflipped chained blocks keep their old PS.
    if "1|0" in rec:
        return _retag_phase_set(rec, final_blocks).replace("1|0", "0|1")
    if "0|1" in rec:
        return _retag_phase_set(rec, final_blocks).replace("0|1", "1|0")
    return rec


def get_altered_vcf(original_vcf, output_vcf, block_relationship_dfs):
    """Write the VCF with methylation-based block flips applied.

    For every chromosome the records are written in three passes: the gaps
    between consecutive ``snp_phased_block_1`` blocks (unchanged), the chained
    blocks (flipped where required) and the remaining unchained blocks
    (unchanged). The output is therefore not position-sorted.

    Args:
        original_vcf: path of the SNP-phased VCF (fetched by region).
        output_vcf: path of the VCF to write; parent folders are created.
        block_relationship_dfs: dict mapping chromosome to relationship table.

    Returns:
        ``(final_block_dict, remaining_dict, flipping_dict)``, each keyed by
        chromosome: chained block spans, unchained block texts and
        ``(block_text, flag)`` pairs.
    """
    final_block_dict = {}
    remaining_dict = {}
    flipping_dict = {}

    # Create the real parent folder; prefixing "./" breaks absolute paths.
    output_dir = os.path.dirname(output_vcf)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Depends only on the input tables, so compute it once for all chromosomes.
    all_final_blocks = get_all_final_blocks_dict(block_relationship_dfs)

    with VariantFile(original_vcf) as called_vcf_file:
        with open(output_vcf, "w") as altered_vcf_file:
            altered_vcf_file.write(str(called_vcf_file.header))
            for chrom in tqdm(block_relationship_dfs.keys()):
                chrom_df = block_relationship_dfs[chrom]
                if chrom_df.empty:
                    final_block_dict[chrom] = []
                    flipping_dict[chrom] = []
                    remaining_dict[chrom] = []
                    continue

                chrom_final_blocks = all_final_blocks[chrom]
                final_block_list, flipping_list = _chain_blocks(chrom_df)
                final_block_dict[chrom] = final_block_list
                flipping_dict[chrom] = flipping_list

                # Pass 1: records between consecutive blocks, copied verbatim.
                # NOTE(review): the gap in front of the very last block
                # (last row's snp_phased_block_2) is not covered, as in the
                # original; confirm whether those records should be kept.
                first_blocks = [
                    _parse_block(block) for block in chrom_df.snp_phased_block_1
                ]
                for left, right in zip(first_blocks, first_blocks[1:]):
                    for rec in called_vcf_file.fetch(chrom, left[1], right[0]):
                        altered_vcf_file.write(str(rec))

                # Pass 2: chained blocks, flipped when their flag is not 1.
                for block_text, flip_flag in flipping_list:
                    block = _parse_block(block_text)
                    for rec in called_vcf_file.fetch(chrom, block[0], block[1]):
                        altered_vcf_file.write(
                            _flip_record(str(rec), flip_flag, chrom_final_blocks)
                        )

                # Pass 3: blocks that are part of no chain, copied verbatim.
                chained_blocks = {block_text for block_text, _ in flipping_list}
                remaining_list = [
                    block_text
                    for block_text in chrom_df.snp_phased_block_1
                    if block_text not in chained_blocks
                ]
                last_block = chrom_df.iloc[-1].snp_phased_block_2
                if last_block not in chained_blocks:
                    remaining_list.append(last_block)
                remaining_dict[chrom] = remaining_list

                for block_text in remaining_list:
                    block = _parse_block(block_text)
                    for rec in called_vcf_file.fetch(chrom, block[0], block[1]):
                        altered_vcf_file.write(str(rec))
    return final_block_dict, remaining_dict, flipping_dict


def get_altered_bam(
    flipping_dict,
    remaining_dict,
    block_relationship_dfs,
    input_bam_file,
    original_bam_file,
    modified_bam_file,
    assignment_path,
    chrom,
):
    """Write the haplotagged BAM file of one chromosome.

    For every read-assignment CSV in ``assignment_path`` (sorted by the
    integer prefix of the file name) two sets of reads are written:

    * reads in the matching SNP block that already carry an ``HP`` tag, with
      ``HP`` 1 and 2 swapped when the block is flipped;
    * reads without ``HP`` tag in the extended region encoded in the file
      name, tagged with the haplotype listed in the CSV (swapped when the
      block is flipped) or written untagged when the CSV lists none.

    A read whose name was handled in the immediately preceding block is
    skipped so that reads overlapping two blocks are not assigned twice.

    Args:
        flipping_dict: chromosome -> ``(block_text, flag)`` pairs.
        remaining_dict: chromosome -> unchained block texts (flag 1).
        block_relationship_dfs: chromosome -> relationship table.
        input_bam_file: BAM the reads are fetched from.
        original_bam_file: BAM used as header template for the output.
        modified_bam_file: path of the BAM to write.
        assignment_path: folder with the read-assignment CSV files.
        chrom: chromosome to process.
    """
    sorted_read_assignment_files = sorted(
        os.listdir(assignment_path), key=lambda x: int(x.split("_")[0])
    )
    # Blocks without a relationship assignment are not affected (flag 1).
    snp_block_flipping_chom = dict.fromkeys(remaining_dict[chrom], 1)
    snp_block_flipping_chom.update(dict(flipping_dict[chrom]))

    chrom_df = block_relationship_dfs[chrom]
    last_file_index = len(sorted_read_assignment_files) - 1
    previous_phased_reads = set()
    previous_unphased_reads = set()

    with pysam.AlignmentFile(input_bam_file, "rb") as input_bam:
        with pysam.AlignmentFile(original_bam_file, "rb") as original_bam:
            with pysam.AlignmentFile(
                modified_bam_file, "wb", template=original_bam
            ) as modified_bam:
                for index, file_name in tqdm(enumerate(sorted_read_assignment_files)):
                    # presumably "<n>_<start>_<end>.csv": only fields 1 and 2
                    # of the split name are read, as region bounds.
                    name_parts = re.split(r"\.|_", file_name)
                    if index == last_file_index:
                        # The last file belongs to the second block of the
                        # last relationship row.
                        current_snp_block = chrom_df.snp_phased_block_2[index - 1]
                    else:
                        current_snp_block = chrom_df.snp_phased_block_1[index]

                    reads_df = pd.read_csv(os.path.join(assignment_path, file_name))
                    read_to_hp_dict = dict(
                        zip(reads_df["read_id"], reads_df["haplotype"])
                    )
                    flip_flag = snp_block_flipping_chom[current_snp_block]
                    snp_block = _parse_block(current_snp_block)
                    extended_start = int(name_parts[1])
                    extended_end = int(name_parts[2])

                    # Reads that were already haplotagged by SNP phasing.
                    current_phased_reads = set()
                    if flip_flag in (1, -1):
                        for read in input_bam.fetch(chrom, snp_block[0], snp_block[1]):
                            if read.query_name in previous_phased_reads:
                                continue
                            if not read.has_tag("HP"):
                                continue
                            current_phased_reads.add(read.query_name)
                            if flip_flag == -1:
                                old_hp = read.get_tag("HP")
                                if old_hp == 1:
                                    read.set_tag(tag="HP", value=2, value_type="i")
                                elif old_hp == 2:
                                    read.set_tag(tag="HP", value=1, value_type="i")
                            # Always write, so an unexpected HP value never
                            # makes a read disappear from the output.
                            modified_bam.write(read)
                    previous_phased_reads = current_phased_reads

                    # Reads without HP tag in the extended block region.
                    current_unphased_reads = set()
                    if flip_flag in (1, -1):
                        for read in input_bam.fetch(
                            chrom, extended_start, extended_end
                        ):
                            if read.query_name in previous_unphased_reads:
                                continue
                            if read.has_tag("HP"):
                                continue
                            current_unphased_reads.add(read.query_name)
                            haplotype = read_to_hp_dict.get(read.query_name)
                            if haplotype == 1:
                                new_hp = 1 if flip_flag == 1 else 2
                                read.set_tag(tag="HP", value=new_hp, value_type="i")
                            elif haplotype == 2:
                                new_hp = 2 if flip_flag == 1 else 1
                                read.set_tag(tag="HP", value=new_hp, value_type="i")
                            # No usable assignment: still unhaplotagged.
                            modified_bam.write(read)
                    previous_unphased_reads = current_unphased_reads


def _load_relationship_csv(csv_file_path, min_required_read, min_diff_perc):
    """Read one relationship CSV, drop its first row and filter weak rows."""
    relationship_df = pd.read_csv(csv_file_path, index_col=0)[1:].reset_index()
    return result_filtering(relationship_df, min_required_read, min_diff_perc)


def _count_agreement(relationship_df):
    """Count agreement between truth and methylation-based relationships.

    Returns:
        ``(correct, wrong, decidable)``: rows where both columns agree on
        ``"same"``/``"not same"``, rows where they contradict each other and
        rows whose ``vcf_file_relationship`` is not ``"cannot decide"``.
    """
    truth = relationship_df.vcf_file_relationship
    called = relationship_df.myth_phasing_relationship
    correct = ((truth == "same") & (called == "same")) | (
        (truth == "not same") & (called == "not same")
    )
    wrong = ((truth == "same") & (called == "not same")) | (
        (truth == "not same") & (called == "same")
    )
    return int(correct.sum()), int(wrong.sum()), int((truth != _UNDECIDED).sum())


def get_precision_recall(input_folder, min_required_read, min_diff_perc):
    """Compute genome-wide accuracy and success rate.

    Args:
        input_folder: folder with one sub-folder of relationship CSVs per
            chromosome. Sub-folders with ``_`` in their name are skipped, as
            in ``get_block_relationships``.
        min_required_read: vote-count threshold for ``result_filtering``.
        min_diff_perc: vote-margin threshold for ``result_filtering``.

    Returns:
        ``(accuracy, correct, decidable, success_rate)`` where accuracy is
        ``correct / (correct + wrong)`` and success rate is
        ``correct / decidable``; a ratio is 0 when its denominator is 0.
    """
    crr_num = 0
    err_num = 0
    total_len = 0
    for chr_name in os.listdir(input_folder):
        if "_" in chr_name:
            continue
        chr_folder_path = os.path.join(input_folder, chr_name)
        for csv_file in os.listdir(chr_folder_path):
            relationship_df = _load_relationship_csv(
                os.path.join(chr_folder_path, csv_file),
                min_required_read,
                min_diff_perc,
            )
            correct, wrong, decidable = _count_agreement(relationship_df)
            crr_num += correct
            err_num += wrong
            total_len += decidable
    accuracy = 0 if crr_num + err_num == 0 else crr_num / (crr_num + err_num)
    success_rate = 0 if total_len == 0 else crr_num / total_len
    return (accuracy, crr_num, total_len, success_rate)


def per_chromosome_precision_recall(
    input_folder, min_required_read, min_diff_perc, figure_output
):
    """Plot accuracy and success rate for chr1..chr22 and return the totals.

    Args:
        input_folder: folder containing the ``chr1`` .. ``chr22`` sub-folders
            of relationship CSVs.
        min_required_read: vote-count threshold for ``result_filtering``.
        min_diff_perc: vote-margin threshold for ``result_filtering``.
        figure_output: path the bar plot is saved to.

    Returns:
        ``(correct, wrong, decidable)`` summed over all 22 chromosomes.
    """
    genome_crr_num = 0
    genome_err_num = 0
    genome_all_num = 0
    chr_result_df = pd.DataFrame(columns=["chr", "accuracy/success rate", "a/s"])
    for chr_num in range(1, 23):
        crr_num = 0
        err_num = 0
        total_len = 0
        chr_folder_path = os.path.join(input_folder, f"chr{chr_num}")
        for csv_file in os.listdir(chr_folder_path):
            relationship_df = _load_relationship_csv(
                os.path.join(chr_folder_path, csv_file),
                min_required_read,
                min_diff_perc,
            )
            correct, wrong, decidable = _count_agreement(relationship_df)
            crr_num += correct
            err_num += wrong
            total_len += decidable

        genome_crr_num += crr_num
        genome_err_num += err_num
        genome_all_num += total_len

        accuracy = 0 if crr_num + err_num == 0 else crr_num / (crr_num + err_num)
        success_rate = 0 if total_len == 0 else crr_num / total_len
        chr_result_df.loc[len(chr_result_df.index)] = [chr_num, accuracy, "accuracy"]
        chr_result_df.loc[len(chr_result_df.index)] = [
            chr_num,
            success_rate,
            "success rate",
        ]

    chr_result_df = chr_result_df.sort_values(by="chr")
    sns_bar = sns.barplot(
        data=chr_result_df, x="chr", y="accuracy/success rate", hue="a/s"
    )
    sns_bar.get_figure().savefig(figure_output)
    return (genome_crr_num, genome_err_num, genome_all_num)


def get_n50(final_block_dict):
    """Return the per-chromosome N50 of the phase-block lengths.

    Args:
        final_block_dict: chromosome -> list of ``(start, end)`` blocks.

    Returns:
        Dict mapping chromosome to the block length at which the running sum
        of ascending lengths first exceeds half of the total length.
        Chromosomes for which that never happens (no blocks, or total length
        0) are absent.
    """
    chrom_n50 = {}
    for chrom, blocks in final_block_dict.items():
        block_lengths = sorted(block[1] - block[0] for block in blocks)
        half_total = sum(block_lengths) / 2
        running_total = 0
        for block_length in block_lengths:
            running_total += block_length
            if running_total > half_total:
                chrom_n50[chrom] = block_length
                break
    return chrom_n50


def get_block_relationships(input_folder, min_required_read=0, min_diff_perc=0):
    """Load and filter the block-relationship tables of every chromosome.

    Args:
        input_folder: folder with one sub-folder per chromosome; entries with
            ``_`` in their name are skipped. CSV files are read in the order
            of the integer prefix of their name.
        min_required_read: vote-count threshold for ``result_filtering``.
        min_diff_perc: vote-margin threshold for ``result_filtering``.

    Returns:
        Dict mapping chromosome name to one concatenated table with a fresh
        0..n-1 index (an empty ``DataFrame`` when the folder has no files).
    """
    relationship_df_by_chr = {}
    for chr_name in os.listdir(input_folder):
        if "_" in chr_name:
            continue
        chr_folder_path = os.path.join(input_folder, chr_name)
        csv_files = sorted(
            os.listdir(chr_folder_path), key=lambda x: int(x.split("_")[0])
        )
        frames = [
            _load_relationship_csv(
                os.path.join(chr_folder_path, csv_file),
                min_required_read,
                min_diff_perc,
            )
            for csv_file in csv_files
        ]
        # pd.concat rejects an empty list, so keep the empty-table fallback.
        relationship_df_by_chr[chr_name] = (
            pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        )
    return relationship_df_by_chr


def main(argv):
    """Run the post-processing: write the altered VCF and per-chromosome BAMs.

    With ``--high_success_rate_param`` both filtering thresholds are 0;
    otherwise ``--minimum_coverage`` and ``--voting_difference`` are used.
    One ``<output_bam>.<chrom>.methtagged.bam`` is written per chromosome,
    using ``--threads`` worker processes.
    """
    args = parse_arg(argv)
    methphsing_output_folder = args.meth_phasing_input_folder

    if args.high_success_rate_param:
        min_required_read, min_diff_perc = 0, 0
    else:
        min_required_read = args.minimum_coverage
        min_diff_perc = args.voting_difference
    block_relationship_dfs = get_block_relationships(
        methphsing_output_folder,
        min_required_read=min_required_read,
        min_diff_perc=min_diff_perc,
    )

    # output VCF file
    _, remaining_dict, flipping_dict = get_altered_vcf(
        args.vcf_called, args.output_vcf, block_relationship_dfs
    )

    # output BAM files, one job per chromosome
    bam_jobs = [
        (
            flipping_dict,
            remaining_dict,
            block_relationship_dfs,
            args.input_bam_file,
            args.input_bam_file,
            f"{args.output_bam}.{chrom}.methtagged.bam",
            os.path.join(f"{methphsing_output_folder}", f"{chrom}_read_assignment"),
            chrom,
        )
        for chrom in block_relationship_dfs
    ]
    with Pool(args.threads) as pool:
        pool.starmap(get_altered_bam, bam_jobs)

    # TODO: N50 output and accuracy / success-rate reporting against
    # --vcf_truth are not wired up here; get_n50 and
    # per_chromosome_precision_recall are available for that.


if __name__ == "__main__":
    main(sys.argv[1:])