Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions gnomad/utils/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def count_variants_by_group(
use_table_group_by: bool = False,
singleton_expr: Optional[hl.expr.BooleanExpression] = None,
max_af: Optional[float] = None,
skip_downsamplings: bool = False,
) -> Union[hl.Table, Any]:
"""
Count number of observed or possible variants by context, ref, alt, and optionally methylation_level.
Expand Down Expand Up @@ -145,6 +146,7 @@ def count_variants_by_group(
[0].AC == 1`. Default is None.
:param max_af: Maximum variant allele frequency to keep. By default, no cutoff is
applied.
:param skip_downsamplings: Whether or not to skip pulling the downsampling data.
:return: Table including 'variant_count' annotation and if requested,
`singleton_count` and downsampling counts.
"""
Expand Down Expand Up @@ -206,12 +208,13 @@ def count_variants_by_group(
pop,
pop,
)
agg[f"downsampling_counts_{pop}"] = downsampling_counts_expr(
agg[f"downsampling_counts_{pop}"] = pop_counts_expr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If skipping downsamplings, we probably don't want to name it downsampling_counts

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what if both downsamplings and a split by pop is wanted? It seems like this should be an addition rather than a skip

freq_expr,
freq_meta_expr,
pop,
max_af=max_af,
downsamplings=downsamplings,
skip_downsamplings=skip_downsamplings,
)
if count_singletons:
logger.info(
Expand All @@ -220,12 +223,13 @@ def count_variants_by_group(
pop,
pop,
)
agg[f"singleton_downsampling_counts_{pop}"] = downsampling_counts_expr(
agg[f"singleton_downsampling_counts_{pop}"] = pop_counts_expr(
freq_expr,
freq_meta_expr,
pop,
max_af=max_af,
downsamplings=downsamplings,
skip_downsamplings=skip_downsamplings,
singleton=True,
)
# Apply each variant count aggregation in `agg` to get counts for all
Expand All @@ -238,16 +242,17 @@ def count_variants_by_group(
)


def get_downsampling_freq_indices(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking we really should be making use of filter_arrays_by_meta, which didn't exist when we added this in. We can remove this function since it's likely not used much outside this module, or if we are worried about doing that, we can use from deprecated import deprecated See below for the suggested use of filter_arrays_by_meta

def get_pop_freq_indices(
freq_meta_expr: hl.expr.ArrayExpression,
pop: str = "global",
variant_quality: str = "adj",
genetic_ancestry_label: Optional[str] = None,
subset: Optional[str] = None,
downsamplings: Optional[List[int]] = None,
skip_downsamplings: bool = False,
) -> hl.expr.ArrayExpression:
"""
Get indices of dictionaries in meta dictionaries that only have the "downsampling" key with specified `genetic_ancestry_label` and "variant_quality" values.
Get indices of dictionaries in meta dictionaries with specified `genetic_ancestry_label`, `variant_quality` values, and downsamplings if specified.

:param freq_meta_expr: ArrayExpression containing the set of groupings for each
element of the `freq_expr` array (e.g., [{'group': 'adj'}, {'group': 'adj',
Expand All @@ -264,6 +269,7 @@ def get_downsampling_freq_indices(
key in `freq_meta_expr`.
:param downsamplings: Optional List of integers specifying what downsampling
indices to obtain. Default is None, which will return all downsampling indices.
:param skip_downsamplings: Whether or not to skip pulling the downsampling data.
:return: ArrayExpression of indices of dictionaries in `freq_meta_expr` that only
have the "downsampling" key with specified `genetic_ancestry_label` and
"variant_quality" values.
Expand All @@ -277,12 +283,17 @@ def _get_filter_expr(m: hl.expr.StructExpression) -> hl.expr.BooleanExpression:
filter_expr = (
(m.get("group") == variant_quality)
& (hl.any([m.get(l, "") == pop for l in gen_anc]))
& m.contains("downsampling")
& ~m.contains("sex")
)
if downsamplings is not None:
filter_expr &= hl.literal(downsamplings).contains(
hl.int(m.get("downsampling", "0"))
)

if skip_downsamplings:
filter_expr &= ~m.contains("downsampling")
else:
if downsamplings is not None:
filter_expr &= hl.literal(downsamplings).contains(
hl.int(m.get("downsampling", "0"))
)

if subset is None:
filter_expr &= ~m.contains("subset")
else:
Expand All @@ -291,11 +302,12 @@ def _get_filter_expr(m: hl.expr.StructExpression) -> hl.expr.BooleanExpression:

indices = hl.enumerate(freq_meta_expr).filter(lambda f: _get_filter_expr(f[1]))

# Get an array of indices and meta dictionaries sorted by "downsampling" key.
return hl.sorted(indices, key=lambda f: hl.int(f[1]["downsampling"]))
# Get an array of indices and meta dictionaries sorted by "downsampling"
# key if present.
return hl.sorted(indices, key=lambda f: hl.int(f[1].get("downsampling", "0")))


def downsampling_counts_expr(
def pop_counts_expr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just change to get_counts_expr

freq_expr: hl.expr.ArrayExpression,
freq_meta_expr: hl.expr.ArrayExpression,
pop: str = "global",
Expand All @@ -305,6 +317,7 @@ def downsampling_counts_expr(
genetic_ancestry_label: Optional[str] = None,
subset: Optional[str] = None,
downsamplings: Optional[List[int]] = None,
skip_downsamplings: bool = False,
) -> hl.expr.ArrayExpression:
"""
Return an aggregation expression to compute an array of counts of all downsamplings found in `freq_expr` where specified criteria is met.
Expand Down Expand Up @@ -335,17 +348,19 @@ def downsampling_counts_expr(
subset will be included.
:param downsamplings: Optional List of integers specifying what downsampling
indices to obtain. Default is None, which will return all downsampling counts.
:param skip_downsamplings: Whether of not to skip pulling the downsampling data.
:return: Aggregation Expression for an array of the variant counts in downsamplings
for specified population.
"""
# Get an array of indices sorted by "downsampling" key.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would change these lines to this to use filter_arrays_by_meta

might need to add an option for sex if you don't want to include it, or something to direct match.

Could still split out the filtering/sorting portion from the counts function if wanted, but then I think the count function should really just take the already filtered function.

Suggested change
# Get an array of indices sorted by "downsampling" key.
# Determine the genetic ancestry label to use for filtering by the
# `genetic_ancestry_label` key in `freq_meta_expr`.
genetic_ancestry_label = genetic_ancestry_label or hl.eval(
hl.literal(["gen_anc", "pop"]).find(
lambda c: freq_meta_expr.flatmap(lambda x: x.keys()).contains(c)
)
)
# Build filters to apply to `freq_meta_expr` to get the desired metadata groups for
# the aggregate count expression.
filters = {
"group": [variant_quality],
genetic_ancestry_label: [pop],
**({"subset": [subset]} if subset else {}),
**(
{"downsampling": list(map(str, downsamplings))}
if not skip_downsamplings and downsamplings is not None
else {}
)
}
filters = [(filters, {})]
if subset is None:
filters.append((["subset"], {"keep": False}))
if skip_downsamplings:
filters.append((["downsampling"], {"keep": False}))
elif downsamplings is None:
filters.append((["downsampling"], {}))
# Apply filters to `freq_meta_expr` and it's indices to get the indices of the
# desired metadata groups.
filtered_meta = freq_meta_expr
indices = hl.enumerate(freq_meta_expr)
for f, params in filters:
filtered_meta, filtered_idx = filter_arrays_by_meta(
filtered_meta,
indices,
items_to_filter=f,
**params
)
# Get an array of indices and meta dictionaries sorted by "downsampling" key if
# downsamplings are not skipped.
if not skip_downsamplings:
indices = hl.sorted(indices, key=lambda f: hl.int(f[1]["downsampling"]))
def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression:
"""
Return 1 when variant meets specified criteria (`singleton` or `max_af`), if requested, or with an AC > 0.
:param i: The index of a downsampling.
:return: Returns 1 if the variant in the downsampling with specified index met
the criteria. Otherwise, returns 0.
"""
if singleton:
return hl.int(freq_expr[i].AC == 1)
elif max_af:
return hl.int((freq_expr[i].AC > 0) & (freq_expr[i].AF <= max_af))
else:
return hl.int(freq_expr[i].AC > 0)
# Map `_get_criteria` function to each downsampling indexed by `sorted_indices` to
# generate a list of 1's and 0's for each variant, where the length of the array is
# the total number of downsamplings for the specified population and each element
# in the array indicates if the variant in the downsampling indexed by
# `sorted_indices` meets the specified criteria.
# Return an array sum aggregation that aggregates arrays generated from mapping.
return hl.agg.array_sum(hl.map(_get_criteria, indices.map(lambda x: x[0])))

sorted_indices = get_downsampling_freq_indices(
sorted_indices = get_pop_freq_indices(
freq_meta_expr,
pop,
variant_quality,
genetic_ancestry_label,
subset,
downsamplings,
skip_downsamplings,
).map(lambda x: x[0])

def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression:
Expand Down Expand Up @@ -1148,6 +1163,15 @@ def oe_aggregation_expr(
agg_expr["gen_anc_obs"] = hl.struct(
**{pop: hl.agg.array_sum(ht[f"downsampling_counts_{pop}"]) for pop in pops}
)
agg_expr["gen_anc_oe"] = hl.struct(
**{
pop: hl.map(
lambda x: divide_null(x[0], x[1]),
hl.zip(agg_expr["gen_anc_obs"][pop], agg_expr["gen_anc_exp"][pop]),
)
for pop in pops
}
)

agg_expr = hl.struct(**agg_expr)
return hl.agg.group_by(filter_expr, agg_expr).get(True, hl.missing(agg_expr.dtype))
Expand Down