diff --git a/gnomad/utils/constraint.py b/gnomad/utils/constraint.py index b605474d5..90dd06162 100644 --- a/gnomad/utils/constraint.py +++ b/gnomad/utils/constraint.py @@ -62,6 +62,7 @@ def count_variants_by_group( use_table_group_by: bool = False, singleton_expr: Optional[hl.expr.BooleanExpression] = None, max_af: Optional[float] = None, + skip_downsamplings: bool = False, ) -> Union[hl.Table, Any]: """ Count number of observed or possible variants by context, ref, alt, and optionally methylation_level. @@ -145,6 +146,7 @@ def count_variants_by_group( [0].AC == 1`. Default is None. :param max_af: Maximum variant allele frequency to keep. By default, no cutoff is applied. + :param skip_downsamplings: Whether or not to skip pulling the downsampling data. :return: Table including 'variant_count' annotation and if requested, `singleton_count` and downsampling counts. """ @@ -206,12 +208,13 @@ def count_variants_by_group( pop, pop, ) - agg[f"downsampling_counts_{pop}"] = downsampling_counts_expr( + agg[f"downsampling_counts_{pop}"] = pop_counts_expr( freq_expr, freq_meta_expr, pop, max_af=max_af, downsamplings=downsamplings, + skip_downsamplings=skip_downsamplings, ) if count_singletons: logger.info( @@ -220,12 +223,13 @@ def count_variants_by_group( pop, pop, ) - agg[f"singleton_downsampling_counts_{pop}"] = downsampling_counts_expr( + agg[f"singleton_downsampling_counts_{pop}"] = pop_counts_expr( freq_expr, freq_meta_expr, pop, max_af=max_af, downsamplings=downsamplings, + skip_downsamplings=skip_downsamplings, singleton=True, ) # Apply each variant count aggregation in `agg` to get counts for all @@ -238,16 +242,17 @@ def count_variants_by_group( ) -def get_downsampling_freq_indices( +def get_pop_freq_indices( freq_meta_expr: hl.expr.ArrayExpression, pop: str = "global", variant_quality: str = "adj", genetic_ancestry_label: Optional[str] = None, subset: Optional[str] = None, downsamplings: Optional[List[int]] = None, + skip_downsamplings: bool = False, ) -> 
hl.expr.ArrayExpression: """ - Get indices of dictionaries in meta dictionaries that only have the "downsampling" key with specified `genetic_ancestry_label` and "variant_quality" values. + Get indices of dictionaries in meta dictionaries with specified `genetic_ancestry_label`, `variant_quality` values, and downsamplings if specified. :param freq_meta_expr: ArrayExpression containing the set of groupings for each element of the `freq_expr` array (e.g., [{'group': 'adj'}, {'group': 'adj', @@ -264,6 +269,7 @@ def get_downsampling_freq_indices( key in `freq_meta_expr`. :param downsamplings: Optional List of integers specifying what downsampling indices to obtain. Default is None, which will return all downsampling indices. + :param skip_downsamplings: Whether or not to skip pulling the downsampling data. :return: ArrayExpression of indices of dictionaries in `freq_meta_expr` that only have the "downsampling" key with specified `genetic_ancestry_label` and "variant_quality" values. @@ -277,12 +283,17 @@ def _get_filter_expr(m: hl.expr.StructExpression) -> hl.expr.BooleanExpression: filter_expr = ( (m.get("group") == variant_quality) & (hl.any([m.get(l, "") == pop for l in gen_anc])) - & m.contains("downsampling") + & ~m.contains("sex") ) - if downsamplings is not None: - filter_expr &= hl.literal(downsamplings).contains( - hl.int(m.get("downsampling", "0")) - ) + + if skip_downsamplings: + filter_expr &= ~m.contains("downsampling") + else: + if downsamplings is not None: + filter_expr &= hl.literal(downsamplings).contains( + hl.int(m.get("downsampling", "0")) + ) + if subset is None: filter_expr &= ~m.contains("subset") else: @@ -291,11 +302,12 @@ def _get_filter_expr(m: hl.expr.StructExpression) -> hl.expr.BooleanExpression: indices = hl.enumerate(freq_meta_expr).filter(lambda f: _get_filter_expr(f[1])) - # Get an array of indices and meta dictionaries sorted by "downsampling" key. 
- return hl.sorted(indices, key=lambda f: hl.int(f[1]["downsampling"])) + # Get an array of indices and meta dictionaries sorted by "downsampling" + # key if present. + return hl.sorted(indices, key=lambda f: hl.int(f[1].get("downsampling", "0"))) -def downsampling_counts_expr( +def pop_counts_expr( freq_expr: hl.expr.ArrayExpression, freq_meta_expr: hl.expr.ArrayExpression, pop: str = "global", @@ -305,6 +317,7 @@ def downsampling_counts_expr( genetic_ancestry_label: Optional[str] = None, subset: Optional[str] = None, downsamplings: Optional[List[int]] = None, + skip_downsamplings: bool = False, ) -> hl.expr.ArrayExpression: """ Return an aggregation expression to compute an array of counts of all downsamplings found in `freq_expr` where specified criteria is met. @@ -335,17 +348,19 @@ def downsampling_counts_expr( subset will be included. :param downsamplings: Optional List of integers specifying what downsampling indices to obtain. Default is None, which will return all downsampling counts. + :param skip_downsamplings: Whether or not to skip pulling the downsampling data. :return: Aggregation Expression for an array of the variant counts in downsamplings for specified population. """ # Get an array of indices sorted by "downsampling" key. 
- sorted_indices = get_downsampling_freq_indices( + sorted_indices = get_pop_freq_indices( freq_meta_expr, pop, variant_quality, genetic_ancestry_label, subset, downsamplings, + skip_downsamplings, ).map(lambda x: x[0]) def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression: @@ -1148,6 +1163,15 @@ def oe_aggregation_expr( agg_expr["gen_anc_obs"] = hl.struct( **{pop: hl.agg.array_sum(ht[f"downsampling_counts_{pop}"]) for pop in pops} ) + agg_expr["gen_anc_oe"] = hl.struct( + **{ + pop: hl.map( + lambda x: divide_null(x[0], x[1]), + hl.zip(agg_expr["gen_anc_obs"][pop], agg_expr["gen_anc_exp"][pop]), + ) + for pop in pops + } + ) agg_expr = hl.struct(**agg_expr) return hl.agg.group_by(filter_expr, agg_expr).get(True, hl.missing(agg_expr.dtype))