@@ -829,7 +829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
829
829
return False
830
830
831
831
832
def convert_local_allele_field_types(fields, schema_instance):
    """
    Add the local-alleles index field and convert localisable fields.

    A new ``call_LA`` array spec is created and appended to the returned
    list. If ``call_AD`` and/or ``call_PL`` are present in ``fields`` they
    are converted *in place* to ``call_LAD`` / ``call_LPL``: the spec is
    renamed, its ``source`` is cleared, its last dimension is replaced by a
    local-alleles dimension and " (local-alleles)" is appended to its
    description. The dimensions used by the new and converted fields are
    registered in ``schema_instance.dimensions``.

    :param fields: List of array specs; must contain ``call_genotype``.
    :param schema_instance: Schema whose ``dimensions`` mapping is updated.
        It must already define the ``ploidy`` dimension.
    :return: A new list holding every entry of ``fields`` plus ``call_LA``.
    :raises KeyError: If ``fields`` has no ``call_genotype`` entry.
    :raises ValueError: If the ploidy is not 2, or if one of the fields
        that would be created (``call_LA``, ``call_LAD``, ``call_LPL``)
        already exists in ``fields``.
    """
    fields_by_name = {field.name: field for field in fields}
    gt = fields_by_name["call_genotype"]

    if schema_instance.get_shape(["ploidy"])[0] != 2:
        raise ValueError("Local alleles only supported on diploid data")

    ad = fields_by_name.get("call_AD")
    pl = fields_by_name.get("call_PL")

    # Validate everything before touching any spec or the schema so that a
    # failure never leaves the inputs half-converted.
    new_names = ["call_LA"]
    if ad is not None:
        new_names.append("call_LAD")
    if pl is not None:
        new_names.append("call_LPL")
    clashes = [name for name in new_names if name in fields_by_name]
    if clashes:
        raise ValueError(
            "Local allele fields already present in schema: " + ", ".join(clashes)
        )

    # Every local-allele field shares the genotype dimensions minus the
    # trailing (ploidy) one.
    dimensions = gt.dimensions[:-1]

    la = vcz.ZarrArraySpec(
        name="call_LA",
        dtype="i1",
        dimensions=(*dimensions, "local_alleles"),
        description=(
            "0-based indices into REF+ALT, indicating which alleles"
            " are relevant (local) for the current sample"
        ),
    )
    schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension(
        size=schema_instance.dimensions["ploidy"].size
    )

    if ad is not None:
        ad.name = "call_LAD"
        ad.source = None
        ad.dimensions = (*dimensions, "local_alleles_AD")
        ad.description += " (local-alleles)"
        schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension(size=2)

    if pl is not None:
        # Derive the local dimension name once, from the ORIGINAL last
        # dimension, and use it both for the spec and for the schema entry.
        local_pl_dimension = "local_" + pl.dimensions[-1].split("_")[-1]
        pl.name = "call_LPL"
        pl.source = None
        pl.description += " (local-alleles)"
        pl.dimensions = (*dimensions, local_pl_dimension)
        schema_instance.dimensions[local_pl_dimension] = vcz.VcfZarrDimension(size=3)

    return [*fields, la]
885
885
886
886
@@ -1042,36 +1042,40 @@ def generate_schema(
1042
1042
if local_alleles is None :
1043
1043
local_alleles = False
1044
1044
1045
+ dimensions = {
1046
+ "variants" : vcz .VcfZarrDimension (
1047
+ size = m , chunk_size = variants_chunk_size or 1000
1048
+ ),
1049
+ "samples" : vcz .VcfZarrDimension (
1050
+ size = n , chunk_size = samples_chunk_size or 10000
1051
+ ),
1052
+ # ploidy added conditionally below
1053
+ "alleles" : vcz .VcfZarrDimension (
1054
+ size = max (self .fields ["ALT" ].vcf_field .summary .max_number + 1 , 2 )
1055
+ ),
1056
+ "filters" : vcz .VcfZarrDimension (size = self .metadata .num_filters ),
1057
+ }
1058
+
1045
1059
schema_instance = vcz .VcfZarrSchema (
1046
1060
format_version = vcz .ZARR_SCHEMA_FORMAT_VERSION ,
1047
- samples_chunk_size = samples_chunk_size ,
1048
- variants_chunk_size = variants_chunk_size ,
1061
+ dimensions = dimensions ,
1049
1062
fields = [],
1050
1063
)
1051
1064
1052
1065
logger .info (
1053
1066
"Generating schema with chunks="
1054
- f"{ schema_instance .variants_chunk_size , schema_instance .samples_chunk_size } "
1067
+ f"variants={ dimensions ['variants' ].chunk_size } , "
1068
+ f"samples={ dimensions ['samples' ].chunk_size } "
1055
1069
)
1056
1070
1057
1071
def spec_from_field (field , array_name = None ):
1058
1072
return vcz .ZarrArraySpec .from_field (
1059
1073
field ,
1060
- num_samples = n ,
1061
- num_variants = m ,
1062
- samples_chunk_size = schema_instance .samples_chunk_size ,
1063
- variants_chunk_size = schema_instance .variants_chunk_size ,
1074
+ schema_instance ,
1064
1075
array_name = array_name ,
1065
1076
)
1066
1077
1067
- def fixed_field_spec (
1068
- name ,
1069
- dtype ,
1070
- source = None ,
1071
- shape = (m ,),
1072
- dimensions = ("variants" ,),
1073
- chunks = None ,
1074
- ):
1078
+ def fixed_field_spec (name , dtype , source = None , dimensions = ("variants" ,)):
1075
1079
compressor = (
1076
1080
vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config ()
1077
1081
if dtype == "bool"
@@ -1081,16 +1085,11 @@ def fixed_field_spec(
1081
1085
source = source ,
1082
1086
name = name ,
1083
1087
dtype = dtype ,
1084
- shape = shape ,
1085
1088
description = "" ,
1086
1089
dimensions = dimensions ,
1087
- chunks = chunks or [schema_instance .variants_chunk_size ],
1088
1090
compressor = compressor ,
1089
1091
)
1090
1092
1091
- alt_field = self .fields ["ALT" ]
1092
- max_alleles = alt_field .vcf_field .summary .max_number + 1
1093
-
1094
1093
array_specs = [
1095
1094
fixed_field_spec (
1096
1095
name = "variant_contig" ,
@@ -1099,16 +1098,12 @@ def fixed_field_spec(
1099
1098
fixed_field_spec (
1100
1099
name = "variant_filter" ,
1101
1100
dtype = "bool" ,
1102
- shape = (m , self .metadata .num_filters ),
1103
1101
dimensions = ["variants" , "filters" ],
1104
- chunks = (schema_instance .variants_chunk_size , self .metadata .num_filters ),
1105
1102
),
1106
1103
fixed_field_spec (
1107
1104
name = "variant_allele" ,
1108
1105
dtype = "O" ,
1109
- shape = (m , max_alleles ),
1110
1106
dimensions = ["variants" , "alleles" ],
1111
- chunks = (schema_instance .variants_chunk_size , max_alleles ),
1112
1107
),
1113
1108
fixed_field_spec (
1114
1109
name = "variant_id" ,
@@ -1142,32 +1137,23 @@ def fixed_field_spec(
1142
1137
1143
1138
if gt_field is not None and n > 0 :
1144
1139
ploidy = max (gt_field .summary .max_number - 1 , 1 )
1145
- shape = [m , n ]
1146
- chunks = [
1147
- schema_instance .variants_chunk_size ,
1148
- schema_instance .samples_chunk_size ,
1149
- ]
1150
- dimensions = ["variants" , "samples" ]
1140
+ # Add ploidy dimension only when needed
1141
+ schema_instance .dimensions ["ploidy" ] = vcz .VcfZarrDimension (size = ploidy )
1142
+
1151
1143
array_specs .append (
1152
1144
vcz .ZarrArraySpec (
1153
1145
name = "call_genotype_phased" ,
1154
1146
dtype = "bool" ,
1155
- shape = list (shape ),
1156
- chunks = list (chunks ),
1157
- dimensions = list (dimensions ),
1147
+ dimensions = ["variants" , "samples" ],
1158
1148
description = "" ,
1149
+ compressor = vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config (),
1159
1150
)
1160
1151
)
1161
- shape += [ploidy ]
1162
- chunks += [ploidy ]
1163
- dimensions += ["ploidy" ]
1164
1152
array_specs .append (
1165
1153
vcz .ZarrArraySpec (
1166
1154
name = "call_genotype" ,
1167
1155
dtype = gt_field .smallest_dtype (),
1168
- shape = list (shape ),
1169
- chunks = list (chunks ),
1170
- dimensions = list (dimensions ),
1156
+ dimensions = ["variants" , "samples" , "ploidy" ],
1171
1157
description = "" ,
1172
1158
compressor = vcz .DEFAULT_ZARR_COMPRESSOR_GENOTYPES .get_config (),
1173
1159
)
@@ -1176,16 +1162,14 @@ def fixed_field_spec(
1176
1162
vcz .ZarrArraySpec (
1177
1163
name = "call_genotype_mask" ,
1178
1164
dtype = "bool" ,
1179
- shape = list (shape ),
1180
- chunks = list (chunks ),
1181
- dimensions = list (dimensions ),
1165
+ dimensions = ["variants" , "samples" , "ploidy" ],
1182
1166
description = "" ,
1183
1167
compressor = vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config (),
1184
1168
)
1185
1169
)
1186
1170
1187
1171
if local_alleles :
1188
- array_specs = convert_local_allele_field_types (array_specs )
1172
+ array_specs = convert_local_allele_field_types (array_specs , schema_instance )
1189
1173
1190
1174
schema_instance .fields = array_specs
1191
1175
return schema_instance
0 commit comments