Skip to content

Commit

Permalink
Fix bug in PGENReader, add tests for fields parameter, and remove unu…
Browse files Browse the repository at this point in the history
…sed n_jobs parameter
  • Loading branch information
salcc committed Nov 21, 2024
1 parent 00f969d commit f35e862
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
51 changes: 51 additions & 0 deletions snputils/snp/io/read/__test__/test_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np

from snputils import VCFReader, BEDReader, PGENReader


# TODO: Fails with KeyError: 'calldata/GT' (genotypes = vcf_dict["calldata/GT"].astype(np.int8))
# def test_vcf_only_samples(data_path, snpobj_vcf):
# snpobj_vcf_only_samples = VCFReader(data_path + "/subset.vcf").read(fields=["IID"])
# assert np.array_equal(snpobj_vcf_only_samples.samples, snpobj_vcf.samples)


def test_bed_only_samples(data_path, snpobj_bed):
snpobj_bed_only_samples = BEDReader(data_path + "/bed/subset").read(fields=["IID"])
assert np.array_equal(snpobj_bed_only_samples.samples, snpobj_bed.samples)


def test_pgen_only_samples(data_path, snpobj_pgen):
snpobj_pgen_only_samples = PGENReader(data_path + "/pgen/subset").read(fields=["IID"])
assert np.array_equal(snpobj_pgen_only_samples.samples, snpobj_pgen.samples)


# TODO: Fails with KeyError: 'calldata/GT' (genotypes = vcf_dict["calldata/GT"].astype(np.int8))
# def test_vcf_only_variants(data_path, snpobj_vcf):
# snpobj_vcf_only_variants = VCFReader(data_path + "/subset.vcf").read(fields=["ID"])
# assert np.array_equal(snpobj_vcf_only_variants.variants_id, snpobj_vcf.variants_id)


def test_bed_only_variants(data_path, snpobj_bed):
snpobj_bed_only_variants = BEDReader(data_path + "/bed/subset").read(fields=["ID"])
assert np.array_equal(snpobj_bed_only_variants.variants_id, snpobj_bed.variants_id)


def test_pgen_only_variants(data_path, snpobj_pgen):
snpobj_pgen_only_variants = PGENReader(data_path + "/pgen/subset").read(fields=["ID"])
assert np.array_equal(snpobj_pgen_only_variants.variants_id, snpobj_pgen.variants_id)


# TODO: Fails with KeyError: 'samples' (samples=vcf_dict["samples"],)
# def test_vcf_only_gt(data_path, snpobj_vcf):
# snpobj_vcf_only_gt = VCFReader(data_path + "/subset.vcf").read(fields=["GT"])
# assert np.array_equal(snpobj_vcf_only_gt.calldata_gt, snpobj_vcf.calldata_gt)


def test_bed_only_gt(data_path, snpobj_bed):
snpobj_bed_only_gt = BEDReader(data_path + "/bed/subset").read(fields=["GT"])
assert np.array_equal(snpobj_bed_only_gt.calldata_gt, snpobj_bed.calldata_gt)


def test_pgen_only_gt(data_path, snpobj_pgen):
snpobj_pgen_only_gt = PGENReader(data_path + "/pgen/subset").read(fields=["GT"])
assert np.array_equal(snpobj_pgen_only_gt.calldata_gt, snpobj_pgen.calldata_gt)
10 changes: 3 additions & 7 deletions snputils/snp/io/read/bed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def read(
variant_ids: Optional[np.ndarray] = None,
variant_idxs: Optional[np.ndarray] = None,
sum_strands: bool = False,
n_jobs: int = 1
) -> SNPObject:
"""
Read a bed fileset (bed, bim, fam) into a SNPObject.
Expand Down Expand Up @@ -174,12 +173,9 @@ def read(

snpobj = SNPObject(
calldata_gt=genotypes if "GT" in fields else None,
samples=fam.get_column("IID").to_numpy() if "IID" in fields else None,
variants_ref=bim.get_column("REF").to_numpy() if "REF" in fields else None,
variants_alt=bim.get_column("ALT").to_numpy() if "ALT" in fields else None,
variants_chrom=bim.get_column("#CHROM").to_numpy() if "#CHROM" in fields else None,
variants_id=bim.get_column("ID").to_numpy() if "ID" in fields else None,
variants_pos=bim.get_column("POS").to_numpy() if "POS" in fields else None,
samples=fam.get_column("IID").to_numpy() if "IID" in fields and "IID" in fam.columns else None,
**{f'variants_{k.lower()}': bim.get_column(v).to_numpy() if v in fields and v in bim.columns else None
for k, v in {'ref': 'REF', 'alt': 'ALT', 'chrom': '#CHROM', 'id': 'ID', 'pos': 'POS'}.items()}
)

log.info("Finished constructing SNPObject")
Expand Down
12 changes: 3 additions & 9 deletions snputils/snp/io/read/pgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def read(
variant_ids: Optional[np.ndarray] = None,
variant_idxs: Optional[np.ndarray] = None,
sum_strands: bool = False,
n_jobs: int = 1
) -> SNPObject:
"""
Read a pgen fileset (pgen, psam, pvar) into a SNPObject.
Expand Down Expand Up @@ -217,14 +216,9 @@ def open_textfile(filename):

snpobj = SNPObject(
calldata_gt=genotypes if "GT" in fields else None,
samples=psam.get_column("IID").to_numpy() if "IID" in psam.columns and fields else None,
variants_ref=pvar.get_column("REF").to_numpy() if "REF" in pvar.columns and fields else None,
variants_alt=pvar.get_column("ALT").to_numpy() if "ALT" in pvar.columns and fields else None,
variants_chrom=pvar.get_column("#CHROM").to_numpy() if "#CHROM" in pvar.columns and fields else None,
variants_id=pvar.get_column("ID").to_numpy() if "ID" in pvar.columns and fields else None,
variants_pos=pvar.get_column("POS").to_numpy() if "POS" in pvar.columns and fields else None,
variants_filter_pass=pvar.get_column("FILTER").to_numpy() if "FILTER" in pvar.columns and fields else None,
variants_qual=pvar.get_column("QUAL").to_numpy() if "QUAL" in pvar.columns and fields else None,
samples=psam.get_column("IID").to_numpy() if "IID" in fields and "IID" in psam.columns else None,
**{f'variants_{k.lower()}': pvar.get_column(v).to_numpy() if v in fields and v in pvar.columns else None
for k, v in {'ref': 'REF', 'alt': 'ALT', 'chrom': '#CHROM', 'id': 'ID', 'pos': 'POS', 'filter_pass': 'FILTER', 'qual': 'QUAL'}.items()}
)

log.info("Finished constructing SNPObject")
Expand Down

0 comments on commit f35e862

Please sign in to comment.