diff --git a/phylofisher/forest.py b/phylofisher/forest.py index 0a5c8b7..56b51b5 100755 --- a/phylofisher/forest.py +++ b/phylofisher/forest.py @@ -15,53 +15,119 @@ from matplotlib.backends.backend_pdf import PdfPages from phylofisher import help_formatter +from phylofisher.db_map import database, Taxonomies, Metadata plt.style.use('ggplot') def configure_colors(): ''' - Configure colors for taxonomic groups from color_conf file. + Configure colors for taxonomic groups from database. :return: dictionary with taxonomic groups as keys and colors as values :rtype: dict ''' color_dict = dict() - with open(color_conf, 'r') as infile: - infile.readline() - for line in infile: - line = line.strip() - tax, color = line.split('\t') - color_dict[tax] = color + + # If we're in local run mode and have a color_conf file, use it + if 'color_conf' in globals() and os.path.exists(color_conf): + with open(color_conf, 'r') as infile: + infile.readline() + for line in infile: + line = line.strip() + tax, color = line.split('\t') + color_dict[tax] = color + else: + # Use database colors + db_query = Taxonomies.select(Taxonomies.taxonomy, Taxonomies.color) + for q in db_query: + if q.color: + color_dict[q.taxonomy] = q.color + else: + color_dict[q.taxonomy] = 'white' return color_dict -def parse_metadata(metadata, input_metadata=None): +def parse_metadata(database_path, input_metadata=None): ''' - Parse metadata from dataset and input_metadata (if provided) + Parse metadata from database and input_metadata (if provided) - :param metadata: path to the metadata file - :type metadata: str - :param input_metadata: if input is input metadata instead of database metadata, defaults to None - :type input_metadata: bool, optional + :param database_path: path to the database file + :type database_path: str + :param input_metadata: path to input metadata file, defaults to None + :type input_metadata: str, optional + :return: tuple of metadata dictionary and color dictionary + :rtype: tuple + ''' 
+ # Parse database metadata + metadata_comb = {} + + + database.init(database_path) + database.connect() + + color_dict = configure_colors() + + db_query = Metadata.select(Metadata.short_name, Metadata.long_name, Metadata.higher_taxonomy, Metadata.lower_taxonomy) + for q in db_query: + tax = q.short_name + full = q.long_name + group = Taxonomies.get(Taxonomies.id == q.higher_taxonomy).taxonomy + sub_tax = Taxonomies.get(Taxonomies.id == q.lower_taxonomy).taxonomy + + if group not in color_dict or color_dict[group].lower() in ['x', 'xx']: + color_dict[group] = 'white' + metadata_comb[tax] = {'Higher Taxonomy': group, 'col': color_dict[group], 'full': full, + 'Lower Taxonomy': sub_tax} + + database.close() + + + # Parse input metadata if provided + if input_metadata and os.path.exists(input_metadata): + for line in open(input_metadata): + if "FILE_NAME" not in line: + metadata_input = line.split('\t') + tax = metadata_input[2].strip().split('_')[0] + group = metadata_input[3].strip() + full = metadata_input[6].strip() + sub_tax = metadata_input[4] + metadata_comb[tax] = {'Higher Taxonomy': group, 'col': "white", 'full': full, 'Lower Taxonomy': sub_tax} + + return metadata_comb, color_dict + + +def parse_metadata_tsv(metadata_file, input_metadata=None): + ''' + Parse metadata from TSV files (for backward compatibility with local runs) + + :param metadata_file: path to the metadata TSV file + :type metadata_file: str + :param input_metadata: path to input metadata file, defaults to None + :type input_metadata: str, optional :return: tuple of metadata dictionary and color dictionary :rtype: tuple ''' color_dict = configure_colors() metadata_comb = {} - for line_ in open(metadata): - if 'Full Name' not in line_: - sline = line_.split('\t') - tax = sline[0].strip() - group = sline[2].strip() - sub_tax = sline[3] - full = sline[1].strip() - if group not in color_dict or color_dict[group].lower() in ['x', 'xx']: - color_dict[group] = 'white' - metadata_comb[tax] = {'Higher 
Taxonomy': group, 'col': color_dict[group], 'full': full, - 'Lower Taxonomy': sub_tax} - if input_metadata: + + # Parse database metadata from TSV + if os.path.exists(metadata_file): + for line_ in open(metadata_file): + if 'Full Name' not in line_: + sline = line_.split('\t') + tax = sline[0].strip() + group = sline[2].strip() + sub_tax = sline[3] + full = sline[1].strip() + if group not in color_dict or color_dict[group].lower() in ['x', 'xx']: + color_dict[group] = 'white' + metadata_comb[tax] = {'Higher Taxonomy': group, 'col': color_dict[group], 'full': full, + 'Lower Taxonomy': sub_tax} + + # Parse input metadata if provided + if input_metadata and os.path.exists(input_metadata): for line in open(input_metadata): if "FILE_NAME" not in line: metadata_input = line.split('\t') @@ -70,6 +136,7 @@ def parse_metadata(metadata, input_metadata=None): full = metadata_input[6].strip() sub_tax = metadata_input[4] metadata_comb[tax] = {'Higher Taxonomy': group, 'col': "white", 'full': full, 'Lower Taxonomy': sub_tax} + return metadata_comb, color_dict @@ -82,7 +149,7 @@ def suspicious_clades(tree): :return: tuple of tree name and list of suspicious clades :rtype: tuple ''' - t = Tree(tree) + t = Tree(tree, format=1) # midpoint rooted tree R = t.get_midpoint_outgroup() t.set_outgroup(R) @@ -91,7 +158,17 @@ def suspicious_clades(tree): for node in t.traverse('preorder'): if (node.is_root() is False) and (node.is_leaf() is False): # report only clades which encompass less than a half of all oranisms - if node.support >= 70 and (len(node) < (len(t) - len(node))): + try: + sh_alrt = float(node.name.split('/')[0]) + ufboot = float(node.name.split('/')[1]) + except IndexError: + ufboot = 0 + sh_alrt = 0 + except ValueError: + ufboot = 0 + sh_alrt = 0 + + if ufboot >= 95 and sh_alrt >= 80 and (len(node) < (len(t) - len(node))): clade = node.get_leaf_names() if len(clade) > 1: # do we need this statement?
supported_clades.append(clade) @@ -120,7 +197,7 @@ def get_best_candidates(tree_file): :return: set of best candidate sequences :rtype: set ''' - t = Tree(tree_file) + t = Tree(tree_file, format=1) top_rank = defaultdict(dict) for node in t.traverse('preorder'): if node.is_leaf(): @@ -210,7 +287,7 @@ def collect_contaminants(tree_file, cont_dict): :return: set of proven contaminants, set of proven contamination (same names as in csv result tables) :rtype: tuple(set, set) ''' - t = Tree(tree_file) + t = Tree(tree_file, format=1) R = t.get_midpoint_outgroup() t.set_outgroup(R) cont_table_names = set() @@ -300,7 +377,7 @@ def tree_to_tsvg(tree_file, contaminants=None, backpropagation=None): table = open(f"{output_folder}/{name_.split('_')[0]}.tsv", 'r') top_ranked = get_best_candidates(tree_file) - t = Tree(tree_file) + t = Tree(tree_file, format=1) ts = TreeStyle() R = t.get_midpoint_outgroup() t.set_outgroup(R) @@ -482,8 +559,18 @@ def format_nodes(node, node_style, sus_clades, t): :return: tuple of TextFace object and updated number of suspicious clades :rtype: tuple ''' - supp = TextFace(f'{int(node.support)}', fsize=8) - if node.support >= 70: + try: + sh_alrt = float(node.name.split('/')[0]) + ufboot = float(node.name.split('/')[1]) + except IndexError: + ufboot = 0 + sh_alrt = 0 + except ValueError: + ufboot = 0 + sh_alrt = 0 + + supp = TextFace(f'{int(sh_alrt)}/{int(ufboot)}', fsize=8) + if ufboot >= 95 and sh_alrt >= 80: supp.bold = True taxons = set() orgs = node.get_leaf_names() @@ -699,6 +786,10 @@ def backpropagate_contamination(tree_file, cont_names): args.metadata = f'{args.input}/metadata.tsv' args.input_metadata = f'{args.input}/input_metadata.tsv' color_conf = f'{args.input}/tree_colors.tsv' + + # For local runs, we might need to parse TSV files if database is not available + # This maintains backward compatibility + args.database_path = None # Indicates to use TSV files args.input = f'{args.input}/trees' @@ -707,17 +798,24 @@ def
backpropagate_contamination(tree_file, cont_names): config = configparser.ConfigParser() config.read('config.ini') dfo = str(Path(config['PATHS']['database_folder']).resolve()) - args.metadata = str(os.path.join(dfo, 'metadata.tsv')) - color_conf = str(Path(config['PATHS']['color_conf']).resolve()) + args.database_path = str(os.path.join(dfo, 'phylofisher.db')) args.input_metadata = str(os.path.abspath(config['PATHS']['input_file'])) if not args.backpropagate: os.mkdir(output_folder) - trees = glob.glob(f"{trees_folder}/*.raxml.support") + trees = glob.glob(f"{trees_folder}/*.treefile") number_of_genes = len(trees) - metadata, tax_col = parse_metadata(args.metadata, args.input_metadata) + + # Parse metadata based on whether we're using database or local TSV files + if args.local_run and hasattr(args, 'metadata') and os.path.exists(args.metadata): + # For local runs with TSV files, use a modified parse function + metadata, tax_col = parse_metadata_tsv(args.metadata, args.input_metadata) + else: + # Use database + metadata, tax_col = parse_metadata(args.database_path, args.input_metadata) + threads = args.threads if not args.backpropagate: diff --git a/phylofisher/sgt_constructor.py b/phylofisher/sgt_constructor.py index f472316..f2c4ffc 100755 --- a/phylofisher/sgt_constructor.py +++ b/phylofisher/sgt_constructor.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import configparser import os -import shutil +import sys import subprocess import textwrap from pathlib import Path @@ -23,15 +23,15 @@ def get_genes(length_filter): else: ret = [] - len_filt_bmge_dir = f'{args.output}/length_filtration/bmge' - bmge_out_files = [file for file in os.listdir(len_filt_bmge_dir) if file.endswith('.bmge')] + len_filt_dir = f'{args.output}/length_filtered' + bmge_out_files = [file for file in os.listdir(len_filt_dir) if file.endswith('.length_filtered')] for bmge_out_file in bmge_out_files: - with open(f'{len_filt_bmge_dir}/{bmge_out_file}', 'r') as infile: + with 
open(f'{len_filt_dir}/{bmge_out_file}', 'r') as infile: line = infile.readline() if line == '': pass else: - ret.append(bmge_out_file.split('.bmge')[0]) + ret.append(bmge_out_file.split('.length_filtered')[0]) return ret @@ -50,7 +50,8 @@ def make_config(length_filter): f'trees_only={args.trees_only}', f'no_trees={args.no_trees}', f'database={args.database}', - f'input_metadata={args.input_metadata}' + f'input_metadata={args.input_metadata}', + f'threads_per_job={args.threads_per_job}' ] return ' '.join(ret) @@ -62,7 +63,7 @@ def get_output_files(length_filter): if length_filter: for gene in get_genes(length_filter): - ret.append(f'{args.output}/length_filtration/bmge/{gene}.bmge') + ret.append(f'{args.output}/length_filtered/{gene}.length_filtered') else: if args.no_trees: @@ -82,13 +83,20 @@ def run_snakemake(length_filter=False): f'snakemake', f'-s {SNAKEFILE_PATH}', f'--config {make_config(length_filter)}', - f'--cores {args.threads}', f'--rerun-incomplete', f'--keep-going', f'--nolock', f'--use-conda', ] + if args.profile is not None: + smk_frags.append(f'--profile {args.profile}') + else: + smk_frags.append(f'--cores {args.threads}') + + if args.dry_run: + smk_frags.append(f'-n') + smk_frags.append(get_output_files(length_filter)) smk_cmd = ' '.join(smk_frags) @@ -110,7 +118,12 @@ def run_snakemake(length_filter=False): # Optional Arguments optional.add_argument('-t', '--threads', metavar='N', type=int, default=1, help=textwrap.dedent("""\ - Desired number of threads to be utilized. + Total number of cores/threads for Snakemake to use for parallel job execution. + Default: 1 + """)) + optional.add_argument('--threads-per-job', metavar='N', type=int, default=1, + help=textwrap.dedent("""\ + Number of threads each individual tool (mafft, iqtree) should use. Default: 1 """)) optional.add_argument('--no_trees', action='store_true', @@ -130,6 +143,15 @@ def run_snakemake(length_filter=False): phylip-relaxed (names are not truncated), or nexus. 
Default: fasta """)) + optional.add_argument('--profile', metavar='', type=str, default=None, + help=textwrap.dedent("""\ + Snakemake cluster profile to use for running jobs. + Default: None + """)) + optional.add_argument('-n', '--dry-run', action='store_true', + help=textwrap.dedent("""\ + Perform a dry-run of the Snakemake workflow without executing any jobs. + """)) args = help_formatter.get_args(parser, optional, required, pre_suf=False, inp_dir=True) diff --git a/phylofisher/sgt_constructor.smk b/phylofisher/sgt_constructor.smk index 81c453c..3feb752 100644 --- a/phylofisher/sgt_constructor.smk +++ b/phylofisher/sgt_constructor.smk @@ -11,6 +11,7 @@ trees_only = config['trees_only'] no_trees = config['no_trees'] pf_database = config['database'] input_metadata = config['input_metadata'] +threads_per_job = config['threads_per_job'] # if not trees_only: @@ -40,20 +41,22 @@ rule length_filter_mafft: input: f'{out_dir}/prequal/{{gene}}.aa.filtered' output: - f'{out_dir}/length_filtration/mafft/{{gene}}.aln' + f'{out_dir}/length_filter_mafft/{{gene}}.aln' log: f'{out_dir}/logs/length_filter_mafft/{{gene}}.log' + threads: + threads_per_job conda: 'mafft.yaml' shell: - 'mafft --thread 1 --globalpair --maxiterate 1000 --unalignlevel 0.6 {input} >{output} 2>{log}' + 'mafft --thread {threads} --globalpair --maxiterate 1000 --unalignlevel 0.6 {input} >{output} 2>{log}' rule length_filter_divvier: input: - f'{out_dir}/length_filtration/mafft/{{gene}}.aln' + f'{out_dir}/length_filter_mafft/{{gene}}.aln' output: - f'{out_dir}/length_filtration/divvier/{{gene}}.aln.partial.fas', - f'{out_dir}/length_filtration/divvier/{{gene}}.aln.PP' + f'{out_dir}/length_filter_divvier/{{gene}}.aln.partial.fas', + f'{out_dir}/length_filter_divvier/{{gene}}.aln.PP' log: f'{out_dir}/logs/length_filter_divvier/{{gene}}.log' conda: @@ -62,15 +65,15 @@ rule length_filter_divvier: f''' divvier -mincol 4 -partial {{input}} >{{log}} 2>{{log}} - mv 
{out_dir}/length_filtration/mafft/{{wildcards.gene}}.aln.partial.fas {out_dir}/length_filtration/divvier &> {{log}} - mv {out_dir}/length_filtration/mafft/{{wildcards.gene}}.aln.PP {out_dir}/length_filtration/divvier &> {{log}} + mv {out_dir}/length_filter_mafft/{{wildcards.gene}}.aln.partial.fas {out_dir}/length_filter_divvier &> {{log}} + mv {out_dir}/length_filter_mafft/{{wildcards.gene}}.aln.PP {out_dir}/length_filter_divvier &> {{log}} ''' rule x_to_dash: input: - f'{out_dir}/length_filtration/divvier/{{gene}}.aln.partial.fas' + f'{out_dir}/length_filter_divvier/{{gene}}.aln.partial.fas' output: - f'{out_dir}/length_filtration/bmge/{{gene}}.pre_bmge' + f'{out_dir}/length_filter_bmge/{{gene}}.pre_bmge' log: f'{out_dir}/logs/x_to_dash/{{gene}}.log' run: @@ -80,9 +83,9 @@ rule x_to_dash: rule length_filter_bmge: input: - f'{out_dir}/length_filtration/bmge/{{gene}}.pre_bmge' + f'{out_dir}/length_filter_bmge/{{gene}}.pre_bmge' output: - f'{out_dir}/length_filtration/bmge/{{gene}}.bmge' + f'{out_dir}/length_filter_bmge/{{gene}}.bmge' log: f'{out_dir}/logs/length_filter_bmge/{{gene}}.log' conda: @@ -92,40 +95,49 @@ rule length_filter_bmge: rule length_filtration: input: - f'{out_dir}/length_filtration/bmge/{{gene}}.bmge' + f'{out_dir}/prequal/{{gene}}.aa', + f'{out_dir}/length_filter_bmge/{{gene}}.bmge' output: - f'{out_dir}/length_filtration/bmge/{{gene}}.length_filtered' + f'{out_dir}/length_filtered/{{gene}}.length_filtered' params: threshold=0.5 log: - f'{out_dir}/logs/length_filtration/{{gene}}.log' + f'{out_dir}/logs/length_filtered/{{gene}}.log' run: original_name = f'{wildcards.gene}.length_filtered' length = None - with open(output[0], 'w') as outfile, open(log[0], 'w') as logfile: - for record in SeqIO.parse(input[0], 'fasta'): + with open(log[0], 'w') as logfile: + ids_to_keep = [] + for record in SeqIO.parse(input[1], 'fasta'): if length is None: length = len(record.seq) - coverage = len(str(record.seq).replace('-', '').replace('X', '')) / len(record.seq) 
+                coverage = len(str(record.seq).replace('-', '').replace('X', '')) / len(record.seq) if coverage > params.threshold: - outfile.write(f'>{record.description}\n{record.seq}\n') + ids_to_keep.append(record.description) else: - logfile.write(f'deleted: {record.name} {coverage}') + logfile.write(f'deleted: {record.name} {coverage}\n') - if os.stat(input[0]).st_size == 0: + if os.stat(input[1]).st_size == 0: logfile.write(f'All sequences were removed during length filtration') + with open(output[0], 'w') as outfile: + for record in SeqIO.parse(input[0], 'fasta'): + if record.description in ids_to_keep: + outfile.write(f'>{record.description}\n{str(record.seq)}\n') + rule mafft: input: - f'{out_dir}/length_filtration/bmge/{{gene}}.length_filtered' + f'{out_dir}/length_filtered/{{gene}}.length_filtered' output: f'{out_dir}/mafft/{{gene}}.aln' log: f'{out_dir}/logs/mafft/{{gene}}.log' + threads: + threads_per_job conda: 'mafft.yaml' shell: - 'mafft --thread 1 --globalpair --maxiterate 1000 --unalignlevel 0.6 {input} >{output} 2>{log}' + 'mafft --thread {threads} --globalpair --maxiterate 1000 --unalignlevel 0.6 {input} >{output} 2>{log}' rule divvier: input: @@ -174,7 +186,7 @@ rule remove_gaps: SeqIO.write(records, output[0], "fasta") -def get_raxml_input(wildcards): +def get_iqtree_input(wildcards): gene = '{wildcards.gene}'.format(wildcards=wildcards) if trees_only: return f'{in_dir}/{gene}.fas' @@ -183,11 +195,13 @@ rule iqtree: input: - get_raxml_input + get_iqtree_input output: f'{out_dir}/iqtree/{{gene}}.treefile' log: f'{out_dir}/logs/iqtree/{{gene}}.log' + threads: + threads_per_job conda: 'iqtree.yaml' params: @@ -195,12 +209,12 @@ shell: ''' iqtree -s {input} \ - -pre {params.iqtree_out} \ - -m ELM+C20 \ - -B 1000 \ - -alrt 1000 \ - -T 1 \ - &> {log} + -pre {params.iqtree_out} \ + -m ELM+C20 \ + -B 1000 \ + -alrt 1000 \ + -T {threads} \ + &> {log} ''' if trees_only: @@ -224,7 +238,7 @@ if trees_only: else: rule cp_trees: input:
f'{out_dir}/length_filtration/bmge/{{gene}}.length_filtered', + f'{out_dir}/length_filter_bmge/{{gene}}.bmge', f'{out_dir}/trimal/{{gene}}.final', f'{out_dir}/iqtree/{{gene}}.treefile' output: diff --git a/setup.py b/setup.py index d8b5a4f..2cf1ea0 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='phylofisher', - version='1.2.14', + version='2.0.0', packages=find_packages(), scripts={'phylofisher/fisher.py', 'phylofisher/config.py', @@ -32,11 +32,13 @@ 'phylofisher/utilities/random_resampler.py', 'phylofisher/utilities/astral_runner.py', 'phylofisher/utilities/rtc_binner.py', + 'phylofisher/utilities/leaf_renamer.py', 'phylofisher/utilities/backup_restoration.py', 'phylofisher/utilities/explore_database.py', 'phylofisher/utilities/nucl_matrix_constructor.py', 'phylofisher/utilities/gfmix_runner.py', 'phylofisher/utilities/gfmix_mammal.smk', + 'phylofisher/utilities/dataset_to_database.py', 'phylofisher/gfmix.yaml', 'phylofisher/mammal.yaml', 'phylofisher/prequal.yaml', @@ -53,6 +55,5 @@ license='MIT', author='David Zihala', author_email='zihaladavid@gmail.com', - description='PhyloFisher is a software package for the creation, analysis, and visualization of phylogenomic ' - 'datasets that consist of protein sequences from eukaryotic organisms.' + description='PhyloFisher is a software package for the creation, analysis, and visualization of phylogenomic datasets that consist of protein sequences from eukaryotic organisms.' )