46 changes: 35 additions & 11 deletions gnomad_constraint/pipeline/constraint_pipeline.py
@@ -244,6 +244,7 @@ def main(args):
version = args.version
test = args.test
overwrite = args.overwrite
skip_downsamplings = args.skip_downsamplings

max_af = args.max_af
pops = args.pops
@@ -262,6 +263,10 @@ def main(args):
)
logger.info("The following downsamplings will be used: %s", downsamplings)

# If pops is not specified, set it to an empty tuple.
if not pops:
pops = ()

# Drop chromosome Y from version v4.0 (can add back in when obtain chrY
# methylation data).
if int(version[0]) >= 4:
@@ -422,10 +427,15 @@ def main(args):
"mane_select" if version_4_and_above else "canonical"
), # Switch to using MANE Select transcripts rather than canonical for gnomAD v4 and later versions.
global_annotation="training_dataset_params",
skip_downsamplings=skip_downsamplings,
)
if use_v2_release_mutation_ht:
op_ht = op_ht.annotate_globals(use_v2_release_mutation_ht=True)
op_ht.write(getattr(res, f"train_{r}_ht").path, overwrite=overwrite)
# op_ht.write(getattr(res, f"train_{r}_ht").path, overwrite=overwrite)
op_ht.write(
"gs://gnomad-kristen/constraint/gen_anc/train.ht",
overwrite=overwrite,
)
logger.info("Done with creating training dataset.")

if args.build_models:
@@ -436,7 +446,10 @@ def main(args):
# chromosome X, and chromosome Y.
for r in regions:
# TODO: Remove repartition once partition_hint bugs are resolved.
training_ht = getattr(res, f"train_{r}_ht").ht()
# training_ht = getattr(res, f"train_{r}_ht").ht()
training_ht = hl.read_table(
"gs://gnomad-kristen/constraint/gen_anc/train.ht"
)
training_ht = training_ht.repartition(args.training_set_partition_hint)

logger.info("Building %s plateau and coverage models...", r)
@@ -450,15 +463,16 @@ def main(args):
)
hl.experimental.write_expression(
plateau_models,
getattr(res, f"model_{r}_plateau").path,
"gs://gnomad-kristen/constraint/gen_anc/plateau_models.he",
# getattr(res, f"model_{r}_plateau").path,
overwrite=overwrite,
)
if not args.skip_coverage_model:
hl.experimental.write_expression(
coverage_model,
getattr(res, f"model_{r}_coverage").path,
overwrite=overwrite,
)
# if not args.skip_coverage_model:
# hl.experimental.write_expression(
# coverage_model,
# getattr(res, f"model_{r}_coverage").path,
# overwrite=overwrite,
# )
logger.info("Done building %s models.", r)

if args.apply_models:
@@ -486,7 +500,10 @@ def main(args):
exome_ht=getattr(res, f"preprocessed_{r}_exomes_ht").ht(),
context_ht=getattr(res, f"preprocessed_{r}_context_ht").ht(),
mutation_ht=mutation_ht,
plateau_models=getattr(res, f"model_{r}_plateau").he(),
plateau_models=hl.experimental.read_expression(
"gs://gnomad-kristen/constraint/gen_anc/plateau_models.he"
),
# plateau_models=getattr(res, f"model_{r}_plateau").he(),
coverage_model=(
getattr(res, "model_autosome_par_coverage").he()
if not args.skip_coverage_model
@@ -495,6 +512,7 @@ def main(args):
max_af=max_af,
pops=pops,
downsamplings=downsamplings,
skip_downsamplings=skip_downsamplings,
obs_pos_count_partition_hint=args.apply_obs_pos_count_partition_hint,
expected_variant_partition_hint=args.apply_expected_variant_partition_hint,
custom_vep_annotation=custom_vep_annotation,
@@ -509,7 +527,8 @@ def main(args):
)
if use_v2_release_mutation_ht:
oe_ht = oe_ht.annotate_globals(use_v2_release_mutation_ht=True)
oe_ht.write(getattr(res, f"apply_{r}_ht").path, overwrite=overwrite)
# oe_ht.write(getattr(res, f"apply_{r}_ht").path, overwrite=overwrite)
oe_ht.write("gs://gnomad-kristen/constraint/gen_anc/apply.ht")

logger.info(
"Done computing expected variant count and observed:expected ratio."
@@ -992,6 +1011,11 @@ def main(args):
help="Export constraint metrics to tsv file.",
action="store_true",
)
parser.add_argument(
"--skip-downsamplings",
help="Whether to skip downsamplings when 'pops' is specified.",
action="store_true",
)
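
For context, a minimal standalone sketch of the `argparse` behaviour the new flag relies on (the names mirror the diff; the `parse_args` inputs below are purely illustrative): with `action="store_true"`, `args.skip_downsamplings` defaults to `False` and becomes `True` only when `--skip-downsamplings` is passed on the command line.

```python
# Minimal sketch of the store_true pattern used by the new --skip-downsamplings flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-downsamplings",
    help="Whether to skip downsamplings when 'pops' is specified.",
    action="store_true",
)

assert parser.parse_args([]).skip_downsamplings is False
assert parser.parse_args(["--skip-downsamplings"]).skip_downsamplings is True
```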

compute_constraint_args._group_actions.append(populations)

47 changes: 38 additions & 9 deletions gnomad_constraint/utils/constraint.py
@@ -1,4 +1,5 @@
"""Script containing utility functions used in the constraint pipeline."""

import logging
from typing import Dict, List, Optional, Tuple

@@ -16,7 +17,7 @@
compute_pli,
count_variants_by_group,
get_constraint_flags,
get_downsampling_freq_indices,
get_pop_freq_indices,
oe_aggregation_expr,
oe_confidence_interval,
trimer_from_heptamer,
@@ -182,6 +183,7 @@ def create_observed_and_possible_ht(
low_coverage_filter: int = None,
transcript_for_synonymous_filter: str = None,
global_annotation: Optional[str] = None,
skip_downsamplings: bool = False,
) -> hl.Table:
"""
Count the observed variants and possible variants by substitution, context, methylation level, and additional `grouping`.
@@ -238,6 +240,7 @@ def create_observed_and_possible_ht(
:param global_annotation: The annotation name to use as a global StructExpression
annotation containing input parameter values. If no value is supplied, this
global annotation will not be added. Default is None.
:param skip_downsamplings: Whether to skip pulling the downsampling data. Default is False.
:return: Table with observed variant and possible variant count.
"""
if low_coverage_filter is not None:
@@ -292,6 +295,7 @@
count_downsamplings=pops,
use_table_group_by=True,
max_af=max_af,
skip_downsamplings=skip_downsamplings,
)

# TODO: Remove repartition once partition_hint bugs are resolved.
@@ -353,6 +357,7 @@ def apply_models(
high_cov_definition: int = COVERAGE_CUTOFF,
low_coverage_filter: int = None,
use_mane_select: bool = True,
skip_downsamplings: bool = False,
) -> hl.Table:
"""
Compute the expected number of variants and observed:expected ratio using plateau models and coverage model.
@@ -426,6 +431,7 @@ def apply_models(
:param use_mane_select: Use MANE Select transcripts in grouping.
Only used when `custom_vep_annotation` is set to 'transcript_consequences'.
Default is True.
:param skip_downsamplings: Whether to skip pulling the downsampling data. Default is False.

:return: Table with `expected_variants` (expected variant counts) and `obs_exp`
(observed:expected ratio) annotations.
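
As background for the docstring above, a schematic plain-Python sketch of the plateau-model idea (this is not the gnomad package's implementation, and the coefficients and counts are made up): the expected count for a grouping comes from a linear fit of the proportion of observed-to-possible variants against the mutation rate, and the observed:expected ratio is then the observed count divided by that expectation.

```python
# Schematic sketch only; the real pipeline builds and applies these models in Hail.
def expected_variants_sketch(
    mu_snp: float, possible: int, slope: float, intercept: float
) -> float:
    # Proportion of possible variants expected to be observed at this mutation
    # rate (linear plateau fit), scaled by the number of possible variants.
    return (slope * mu_snp + intercept) * possible


# Hypothetical numbers for a single (context, ref, alt, methylation) grouping.
expected = expected_variants_sketch(
    mu_snp=1.2e-8, possible=500, slope=5.0e7, intercept=0.1
)
observed = 200
print(round(expected), round(observed / expected, 2))  # 350 0.57
```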
@@ -477,6 +483,7 @@
partition_hint=obs_pos_count_partition_hint,
filter_coverage_over_0=True,
transcript_for_synonymous_filter=None,
skip_downsamplings=skip_downsamplings,
)

# NOTE: In v2 ht.mu_snp was incorrectly multiplied here by possible_variants, but this multiplication has now been moved,
@@ -524,15 +531,15 @@ def apply_models(

# Store which downsamplings are obtained for each pop in a
# downsampling_meta dictionary.
ds = hl.eval(get_downsampling_freq_indices(ht.freq_meta, pop=pop))
ds = hl.eval(get_pop_freq_indices(ht.freq_meta, pop=pop))
key_names = {key for _, meta_dict in ds for key in meta_dict.keys()}
genetic_ancestry_label = "gen_anc" if "gen_anc" in key_names else "pop"
downsampling_meta[pop] = [
x[1]["downsampling"]
x[1].get("downsampling", "all")
for x in ds
if (x[1][genetic_ancestry_label] == pop)
& (
int(x[1]["downsampling"]) in downsamplings
if x[1][genetic_ancestry_label] == pop
and (
int(x[1].get("downsampling", 0)) in downsamplings
if downsamplings is not None
else True
)
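
To make the new `.get("downsampling", "all")` fallback concrete, here is a plain-Python sketch using hypothetical `freq_meta` entries (the real pairs come from `hl.eval(get_pop_freq_indices(ht.freq_meta, pop=pop))`): entries without a "downsampling" key describe the full genetic ancestry group and are now recorded as "all" instead of raising a KeyError.

```python
# Plain-Python sketch of the downsampling_meta filtering above, on made-up metadata.
pop = "afr"
downsamplings = [1000, 5000]

# (freq index, metadata dict) pairs for one genetic ancestry group.
ds = [
    (3, {"gen_anc": "afr"}),                            # full-group entry: no "downsampling" key
    (7, {"gen_anc": "afr", "downsampling": "1000"}),
    (9, {"gen_anc": "afr", "downsampling": "5000"}),
    (11, {"gen_anc": "afr", "downsampling": "20000"}),  # not requested, filtered out
]

key_names = {key for _, meta_dict in ds for key in meta_dict.keys()}
genetic_ancestry_label = "gen_anc" if "gen_anc" in key_names else "pop"

downsampling_meta = {
    pop: [
        meta.get("downsampling", "all")
        for _, meta in ds
        if meta[genetic_ancestry_label] == pop
        and (
            int(meta.get("downsampling", 0)) in downsamplings
            if downsamplings is not None
            else True
        )
    ]
}
# {'afr': ['1000', '5000']}; with downsamplings=None the full-group entry is kept as 'all'.
print(downsampling_meta)
```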
@@ -897,9 +904,8 @@ def compute_constraint_metrics(
# `annotation_dict` states the filtering rule for each annotation.
annotation_dict = {
# Filter to classic LoF annotations with LOFTEE HC or LC.
"lof_hc_lc": hl.literal(set(classic_lof_annotations)).contains(
ht.annotation
) & ((ht.modifier == "HC") | (ht.modifier == "LC")),
"lof_hc_lc": hl.literal(set(classic_lof_annotations)).contains(ht.annotation)
& ((ht.modifier == "HC") | (ht.modifier == "LC")),
# Filter to LoF annotations with LOFTEE HC.
"lof": ht.modifier == "HC",
# Filter to missense variants.
@@ -1009,6 +1015,29 @@ def compute_constraint_metrics(
z_raw=raw_z_expr,
)

gen_anc_lower_struct = {}
gen_anc_upper_struct = {}
gen_anc_z_raw_struct = {}

# Calculate lower and upper CIs and raw z scores for each pop, excluding downsamplings.
for pop in pops:
obs_expr = ht[ann]["gen_anc_obs"][pop][0]
exp_expr = ht[ann]["gen_anc_exp"][pop][0]
oe_ci_expr = oe_confidence_interval(obs_expr, exp_expr)
raw_z_expr = calculate_raw_z_score(obs_expr, exp_expr)

gen_anc_lower_struct[pop] = oe_ci_expr.lower
gen_anc_upper_struct[pop] = oe_ci_expr.upper
gen_anc_z_raw_struct[pop] = raw_z_expr

# Annotate the table with the structs.
ann_expr[ann] = ann_expr[ann].annotate(
gen_anc_oe_ci=hl.struct(
lower=hl.struct(**gen_anc_lower_struct), upper=hl.struct(**gen_anc_upper_struct)
),
gen_anc_z_raw=hl.struct(**gen_anc_z_raw_struct),
)
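
For intuition about the per-genetic-ancestry raw z scores annotated above, a hedged plain-Python sketch of a chi-squared-style raw z score (the pipeline's actual definition is whatever `calculate_raw_z_score` from the gnomad package implements, which may differ in detail or sign convention):

```python
import math


def raw_z_score_sketch(obs: float, exp: float) -> float:
    """Sketch only: sqrt((obs - exp)^2 / exp), signed so that observing fewer
    variants than expected (stronger constraint) gives a positive score."""
    chisq = (obs - exp) ** 2 / exp
    return math.copysign(math.sqrt(chisq), exp - obs)


# e.g. 40 observed vs 100 expected LoF variants -> a raw z of +6.0.
print(raw_z_score_sketch(40, 100))
```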

ann_expr["constraint_flags"] = add_filters_expr(filters=constraint_flags_expr)
ht = ht.annotate(**ann_expr)
ht = ht.checkpoint(