Skip to content

Commit c263404

Browse files
authored
Pass bit depth as argument if missing from meta (#383)
* Move function to work out dtype to common tools * Add bit depth option to parser * Make bit_depth an input in I19 too * Fix test
1 parent 63a6345 commit c263404

7 files changed

Lines changed: 57 additions & 18 deletions

File tree

src/nexgen/beamlines/I19_2_nxs.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from typing import Any, NamedTuple, Optional
1212

1313
import h5py
14-
import numpy as np
1514
from numpy.typing import ArrayLike
1615
from pydantic import field_validator
1716

17+
from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth
1818
from nexgen.utils import get_iso_timestamp
1919

2020
from .. import log
@@ -230,6 +230,7 @@ def eiger_writer(
230230
vds_offset: int = 0,
231231
notes: dict[str, Any] | None = None,
232232
data_entry_key: str = "data",
233+
bit_depth: int = 32,
233234
):
234235
"""
235236
A function to call the NXmx nexus file writer for Eiger 2X 4M detector.
@@ -251,6 +252,8 @@ def eiger_writer(
251252
dataset name and value its data. Defaults to None.
252253
data_entry_key (str, optional): Dataset entry key in datafiles. eg. for gating mode it's data1.\
253254
Defaults to data.
255+
bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \
256+
Defaults to 32.
254257
255258
Raises:
256259
ValueError: If use_meta is set to False but axes_pos and det_pos haven't been passed.
@@ -338,7 +341,7 @@ def eiger_writer(
338341
logger.info(
339342
"Not using meta file to update metadata, only the external links will be set up."
340343
)
341-
vds_dtype = np.uint32
344+
vds_dtype = define_vds_dtype_from_bit_depth(bit_depth)
342345
# Update axes
343346
# Goniometer
344347
for gax in TR.axes_pos:
@@ -457,6 +460,7 @@ def serial_nexus_writer(
457460
use_meta: bool = False,
458461
vds_offset: int = 0,
459462
n_frames: int | None = None,
463+
bit_depth: int = 32,
460464
notes: dict[str, Any] | None = None,
461465
):
462466
"""Wrapper function to gather all parameters from the beamline and kick off the nexus writer for a \
@@ -473,6 +477,8 @@ def serial_nexus_writer(
473477
n_frames (int | None, optional): Number of images for the nexus file. Only needed if different \
474478
from the tot_num_images in the collection params. If passed, the VDS will only contain the \
475479
number of frames specified here. Defaults to None.
480+
bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \
481+
Defaults to 32.
476482
notes (dict[str, Any] | None, optional): Any additional information to be written as NXnote, \
477483
passed as a dictionary of (key, value) pairs where key represents the dataset name and \
478484
value its data. Defaults to None.
@@ -511,7 +517,8 @@ def serial_nexus_writer(
511517
use_meta,
512518
n_frames,
513519
vds_offset,
514-
notes,
520+
bit_depth=bit_depth,
521+
notes=notes,
515522
)
516523
case DetectorName.TRISTAN:
517524
tristan_writer(master_file, collection_params, timestamps, notes)
@@ -523,6 +530,7 @@ def nexus_writer(
523530
timestamps: tuple[datetime, datetime] = (None, None),
524531
use_meta: bool = False,
525532
data_entry_key: str = "data",
533+
bit_depth: int = 32,
526534
):
527535
"""Wrapper function to gather all parameters from the beamline and kick off the nexus writer for a \
528536
standard experiment on I19-2.
@@ -536,6 +544,8 @@ def nexus_writer(
536544
all parameters will need to be passed manually. Defaults to False.
537545
data_entry_key (str, optional): Dataset entry key in datafiles. eg. for gating mode it's data1.\
538546
Defaults to data.
547+
bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \
548+
Defaults to 32.
539549
"""
540550
collection_params = CollectionParams(**params)
541551
wdir = master_file.parent
@@ -603,6 +613,7 @@ def nexus_writer(
603613
timestamps,
604614
use_meta,
605615
data_entry_key=data_entry_key,
616+
bit_depth=bit_depth,
606617
)
607618
case DetectorName.TRISTAN:
608619
tristan_writer(

src/nexgen/beamlines/SSX_Eiger_nxs.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from pathlib import Path
99
from typing import Literal, get_args
1010

11-
import numpy as np
12-
from numpy.typing import DTypeLike
11+
from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth
1312

1413
from .. import log
1514
from ..nxs_utils import (
@@ -64,16 +63,6 @@ class SerialParams(GeneralParams):
6463
experiment_type: str
6564

6665

67-
def _define_vds_dtype_from_bit_depth(bit_depth: int) -> DTypeLike:
68-
"""Define dtype of VDS based on the passed bit depth."""
69-
if bit_depth == 32:
70-
return np.uint32
71-
elif bit_depth == 8:
72-
return np.uint8
73-
else:
74-
return np.uint16
75-
76-
7766
def _get_beamline_specific_params(beamline: str) -> tuple[BeamlineAxes, EigerDetector]:
7867
"""Get beamline specific axes and eiger description.
7968
@@ -295,7 +284,7 @@ def ssx_eiger_writer(
295284
bit_depth = 32
296285
else:
297286
bit_depth = ssx_params["bit_depth"]
298-
vds_dtype = _define_vds_dtype_from_bit_depth(bit_depth)
287+
vds_dtype = define_vds_dtype_from_bit_depth(bit_depth)
299288
logger.debug(f"VDS dtype will be {vds_dtype}")
300289

301290
# Define Goniometer axes

src/nexgen/command_line/I19_2_cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def nexgen_writer(args):
135135
args.use_meta,
136136
args.vds_offset,
137137
args.n_frames,
138+
bit_depth=args.bit_depth,
138139
)
139140
else:
140141
nexus_writer(
@@ -143,6 +144,7 @@ def nexgen_writer(args):
143144
(_start, _stop),
144145
args.use_meta,
145146
data_entry_key=args.data_key,
147+
bit_depth=args.bit_depth,
146148
)
147149

148150

@@ -304,6 +306,13 @@ def nexgen_writer(args):
304306
default="data",
305307
help="Data entry key of dataset in raw .h5 file. Defaults to data.",
306308
)
309+
parser_nex.add_argument(
310+
"-bits" "--bit-depth",
311+
type=int,
312+
choices=[8, 16, 32],
313+
default=32,
314+
help="Default bit depth for eiger collections, used to define dtype of vds data. Defaults to 32.",
315+
)
307316
parser_nex.set_defaults(func=nexgen_writer)
308317

309318

src/nexgen/command_line/nexus_generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from nexgen.nxs_write.nxmx_writer import EventNXmxFileWriter, NXmxFileWriter
2929
from nexgen.nxs_write.write_utils import find_number_of_images
3030
from nexgen.tools.data_writer import generate_event_files, generate_image_files
31+
from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth
3132
from nexgen.utils import (
3233
get_filename_template,
3334
get_iso_timestamp,
@@ -143,6 +144,7 @@ def write_nxmx_cli(args):
143144

144145
try:
145146
entry_key = args.data_key if args.data_key else "data"
147+
vds_dtype = define_vds_dtype_from_bit_depth(args.bit_depth)
146148
# Aaaaaaaaaaaand write
147149
if params.det.mode == "images":
148150
writer = NXmxFileWriter(
@@ -156,7 +158,7 @@ def write_nxmx_cli(args):
156158
)
157159
writer.write(image_datafiles=datafiles, data_entry_key=entry_key)
158160
if not args.no_vds:
159-
writer.write_vds(args.vds_offset)
161+
writer.write_vds(args.vds_offset, vds_dtype=vds_dtype)
160162
else:
161163
writer = EventNXmxFileWriter(
162164
master_file,
@@ -390,6 +392,13 @@ def _parse_cli() -> argparse.ArgumentParser:
390392
default="data",
391393
help="Data entry key of dataset in raw .h5 file. Defaults to data.",
392394
)
395+
nxmx_parser.add_argument(
396+
"-bits" "--bit-depth",
397+
type=int,
398+
choices=[8, 16, 32],
399+
default=32,
400+
help="Default bit depth for eiger collections, used to define dtype of vds data. Defaults to 32.",
401+
)
393402
nxmx_parser.set_defaults(func=write_nxmx_cli)
394403
demo_parser = subparsers.add_parser(
395404
"2",

src/nexgen/tools/vds_tools.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@
2121
vds_logger = logging.getLogger("nexgen.VDSWriter")
2222

2323

24+
def define_vds_dtype_from_bit_depth(bit_depth: int) -> DTypeLike:
25+
"""Define dtype of VDS based on the passed bit depth."""
26+
if bit_depth == 32:
27+
return np.uint32
28+
elif bit_depth == 8:
29+
return np.uint8
30+
else:
31+
return np.uint16
32+
33+
2434
@dataclass
2535
class Dataset:
2636
name: str

tests/beamlines/test_i19nxs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def test_serial_nexus_writer_calls_correct_writer_for_eiger(
8484
True,
8585
None,
8686
0,
87-
None,
87+
bit_depth=32,
88+
notes=None,
8889
)
8990

9091

tests/tools/test_VDS_tools.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,23 @@
88
from nexgen.tools.vds_tools import (
99
Dataset,
1010
create_virtual_layout,
11+
define_vds_dtype_from_bit_depth,
1112
find_datasets_in_file,
1213
image_vds_writer,
1314
jungfrau_vds_writer,
1415
split_datasets,
1516
)
1617

1718

19+
@pytest.mark.parametrize(
20+
"bit_depth, expected_dtype", [(8, np.uint8), (16, np.uint16), (32, np.uint32)]
21+
)
22+
def test_vds_dtype_from_input(bit_depth, expected_dtype):
23+
d = define_vds_dtype_from_bit_depth(bit_depth)
24+
25+
assert d == expected_dtype
26+
27+
1828
def test_when_get_frames_and_shape_less_than_1000_then_correct():
1929
sshape = split_datasets(["test1"], (500, 10, 10))
2030
assert sshape == [Dataset("test1", (500, 10, 10), 0, 500)]

0 commit comments

Comments
 (0)