Skip to content

Commit 774cf29

Browse files
committed
feat: Handle extra mdoc fields and use pydantic alias for FrameDosesAndNumbers
1 parent 23c972d commit 774cf29

3 files changed

Lines changed: 170 additions & 31 deletions

File tree

src/mdocfile/data_models.py

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1+
import logging
12
import pandas as pd
2-
from pydantic import field_validator, BaseModel
3+
from pydantic import field_validator, model_validator, BaseModel, ConfigDict, Field
34
from pathlib import Path, PureWindowsPath
45
from typing import List, Optional, Tuple, Union, Sequence
56

67
from mdocfile.utils import find_section_entries, find_title_entries
78

9+
log = logging.getLogger('mdocfile')
10+
811

912
class MdocGlobalData(BaseModel):
1013
"""Data model for global data in a SerialEM mdoc file.
1114
1215
https://bio3d.colorado.edu/SerialEM/hlp/html/about_formats.htm
1316
"""
17+
model_config = ConfigDict(extra='allow')
18+
1419
DataMode: Optional[int] = None
1520
ImageSize: Optional[Tuple[int, int]] = None
1621
Montage: Optional[bool] = None
@@ -69,6 +74,10 @@ class MdocSectionData(BaseModel):
6974
7075
https://bio3d.colorado.edu/SerialEM/hlp/html/about_formats.htm
7176
"""
77+
model_config = ConfigDict(extra='allow', # keep extra field data
78+
validate_by_name=True) # use our validations for aliased fields
79+
# serialize_by_alias=True) # use the version of the fieldname the file arrived as
80+
7281
# headers
7382
ZValue: Optional[int] = None
7483
MontSection: Optional[int] = None
@@ -111,7 +120,9 @@ class MdocSectionData(BaseModel):
111120
Union[Tuple[float, float], Tuple[float, float, float]]] = None
112121
SubFramePath: Optional[Union[PureWindowsPath, Path]] = None
113122
NumSubFrames: Optional[int] = None
114-
FrameDosesAndNumbers: Optional[Sequence[Tuple[float, int]]] = None
123+
FrameDosesAndNumbers: Optional[Sequence[Tuple[float, int]]] = Field(
124+
default=None, validation_alias='FrameDosesAndNumber'
125+
)
115126
DateTime: Optional[str] = None
116127
NavigatorLabel: Optional[str] = None
117128
FilterSlitAndLoss: Optional[Tuple[float, float]] = None
@@ -120,6 +131,16 @@ class MdocSectionData(BaseModel):
120131
CameraPixelSize: Optional[float] = None
121132
Voltage: Optional[float] = None
122133

134+
@model_validator(mode='before')
135+
@classmethod
136+
def warn_on_aliases(cls, data):
137+
if isinstance(data, dict):
138+
for field_name, field_info in cls.model_fields.items():
139+
alias = field_info.validation_alias
140+
if alias and alias in data:
141+
log.warning(f"'{alias}' mapped to '{field_name}'")
142+
return data
143+
123144
@field_validator(
124145
'PieceCoordinates',
125146
'SuperMontCoords',
@@ -133,7 +154,6 @@ class MdocSectionData(BaseModel):
133154
'StageOffsets',
134155
'AlignedPieceCoords',
135156
'AlignedPieceCoordsVS',
136-
'FrameDosesAndNumbers',
137157
'FilterSlitAndLoss',
138158
'MultiShotHoleAndPosition',
139159
mode="before")
@@ -143,28 +163,31 @@ def multi_number_string_to_tuple(cls, value: str):
143163
value = tuple(value.split())
144164
return value
145165

166+
@field_validator('FrameDosesAndNumbers', mode="before")
167+
@classmethod
168+
def parse_frame_doses_and_numbers(cls, value: str):
169+
"""Parse 'dose1 num1 dose2 num2 ...' into [(dose1, num1), ...]"""
170+
if isinstance(value, str):
171+
parts = value.split()
172+
return [(float(parts[i]), int(parts[i+1])) for i in range(0, len(parts)-1, 2)]
173+
return value
174+
146175
@classmethod
147176
def from_lines(cls, lines: List[str]):
148-
lines = [line.strip('[]')
149-
for line
150-
in lines
151-
if len(line) > 0]
152-
key_value_pairs = [line.split('=') for line in lines]
153-
key_value_pairs = [
154-
(k.strip(), v.strip())
155-
for k, v
156-
in key_value_pairs
157-
]
158-
lines = {k: v for k, v in key_value_pairs}
159-
return cls(**lines)
177+
data = {}
178+
for line in lines:
179+
line = line.strip().strip('[]')
180+
if not line or '=' not in line:
181+
continue
182+
k, v = line.split('=', 1)
183+
data[k.strip()] = v.strip()
184+
return cls(**data)
160185

161186
@classmethod
162187
def from_dataframe(cls, series: pd.Series):
163-
section = {}
164-
for k in cls.model_fields.keys():
165-
if k in series.index.tolist():
166-
section[k] = series[k]
167-
return cls(**section)
188+
skip = set(MdocGlobalData.model_fields.keys()) | {'titles'}
189+
data = {k: series[k] for k in series.index if k not in skip}
190+
return cls(**data)
168191

169192
def to_string(self):
170193
data = self.model_dump()
@@ -173,6 +196,8 @@ def to_string(self):
173196
for k, v in data.items():
174197
if v is None:
175198
continue
199+
elif k == 'FrameDosesAndNumbers' and isinstance(v, list):
200+
v = ' '.join(f'{d} {n}' for d, n in v)
176201
elif isinstance(v, tuple):
177202
v = ' '.join(str(el) for el in v)
178203
elif v == 'nan':
@@ -213,19 +238,35 @@ def from_lines(cls, file_lines: List[str]) -> 'Mdoc':
213238
for start_idx, end_idx
214239
in zip(split_idxs, split_idxs[1:])
215240
]
241+
242+
# Warn about extra fields
243+
extra_fields = set(global_data.model_extra.keys())
244+
for s in section_data:
245+
extra_fields.update(s.model_extra.keys())
246+
if extra_fields:
247+
log.warning(f"Unknown fields will be preserved: {extra_fields}")
248+
216249
return cls(titles=titles, global_data=global_data, section_data=section_data)
217250

218251
def to_dataframe(self) -> pd.DataFrame:
219252
"""
220253
Convert an Mdoc object to a pandas DataFrame
221254
"""
222255
global_data = self.global_data.model_dump()
223-
section_data = {
224-
k: [section.model_dump()[k] for section in self.section_data]
225-
for k
226-
in self.section_data[0].model_dump().keys()
227-
}
228-
df = pd.DataFrame(data=section_data)
256+
# Include extra fields from global_data
257+
global_data.update(self.global_data.model_extra)
258+
259+
# Collect all keys from all sections (including extras)
260+
all_keys = set()
261+
section_dicts = []
262+
for section in self.section_data:
263+
d = section.model_dump()
264+
d.update(section.model_extra)
265+
section_dicts.append(d)
266+
all_keys.update(d.keys())
267+
268+
# Build section_data dict with None for missing keys
269+
df = pd.DataFrame(data=dict((k, [d.get(k) for d in section_dicts]) for k in all_keys))
229270

230271
# add duplicate copies of global data and mdoc file titles to each row of
231272
# the dataframe - tidy data is easier to analyse

tests/test_data_models.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def test_to_string_is_valid_mdoc(tilt_series_mdoc_file):
8383
mdoc = Mdoc.from_file(tilt_series_mdoc_file)
8484
with NamedTemporaryFile() as tmp:
8585
tmp.write(mdoc.to_string().encode())
86+
tmp.flush()
8687
mdoc2 = Mdoc.from_file(tmp.name)
8788
mdoc_dict = mdoc.section_data[0].model_dump()
8889
mdoc2_dict = mdoc2.section_data[0].model_dump()
@@ -91,4 +92,96 @@ def test_to_string_is_valid_mdoc(tilt_series_mdoc_file):
9192
assert k1 == k2
9293

9394
def test_section_data_from_path():
94-
section = MdocSectionData(SubFramePath=Path('bla.tif'))
95+
some_path = Path('bla.tif')
96+
section = MdocSectionData(SubFramePath=some_path)
97+
assert section.SubFramePath == some_path
98+
assert f'SubFramePath = {some_path}' in section.to_string()
99+
100+
def test_fieldname_alias_mapping():
101+
"""Test that aliased field names are mapped to canonical names."""
102+
lines = """[ZValue = 0]
103+
TiltAngle = 5.0
104+
FrameDosesAndNumber = 2.5 10 3.0 20
105+
""".split('\n')
106+
107+
section = MdocSectionData.from_lines(lines)
108+
109+
# Should be accessible via canonical name
110+
assert section.FrameDosesAndNumbers is not None
111+
assert section.FrameDosesAndNumbers == [(2.5, 10), (3.0, 20)]
112+
113+
# to_string should output the canonical name (FrameDosesAndNumbers)
114+
output = section.to_string()
115+
assert 'FrameDosesAndNumbers = 2.5 10 3.0 20' in output
116+
# Original aliased name should not appear
117+
assert 'FrameDosesAndNumber =' not in output
118+
119+
120+
def test_extra_fields_round_trip():
121+
"""Test that unknown fields are preserved through full Mdoc round-trip."""
122+
mdoc_str = """DataMode = 1
123+
ImageFile = test.mrc
124+
125+
[ZValue = 0]
126+
TiltAngle = 5.0
127+
CountsPerElectron = 42.0
128+
UnknownCustomField = some_value
129+
"""
130+
131+
mdoc = Mdoc.from_string(mdoc_str)
132+
133+
# Extra fields stored in model_extra
134+
assert mdoc.section_data[0].model_extra['CountsPerElectron'] == '42.0'
135+
assert mdoc.section_data[0].model_extra['UnknownCustomField'] == 'some_value'
136+
137+
# Round-trip preserves extra fields
138+
mdoc2 = Mdoc.from_string(mdoc.to_string())
139+
assert mdoc2.section_data[0].model_extra['CountsPerElectron'] == '42.0'
140+
assert mdoc2.section_data[0].model_extra['UnknownCustomField'] == 'some_value'
141+
142+
143+
def test_dataframe_alias_mapping():
144+
"""Test that aliased field names work through dataframe round-trip."""
145+
mdoc_str = """DataMode = 1
146+
ImageFile = test.mrc
147+
148+
[ZValue = 0]
149+
TiltAngle = 5.0
150+
FrameDosesAndNumber = 2.5 10 3.0 20
151+
"""
152+
153+
mdoc = Mdoc.from_string(mdoc_str)
154+
df = mdoc.to_dataframe()
155+
156+
# Dataframe should have canonical name
157+
assert 'FrameDosesAndNumbers' in df.columns
158+
159+
# Round-trip through dataframe
160+
mdoc2 = Mdoc.from_dataframe(df)
161+
assert mdoc2.section_data[0].FrameDosesAndNumbers == [(2.5, 10), (3.0, 20)]
162+
163+
164+
def test_dataframe_extra_fields_round_trip():
165+
"""Test that extra fields survive dataframe round-trip."""
166+
mdoc_str = """DataMode = 1
167+
ImageFile = test.mrc
168+
169+
[ZValue = 0]
170+
TiltAngle = 5.0
171+
CountsPerElectron = 42.0
172+
UnknownCustomField = some_value
173+
"""
174+
175+
mdoc = Mdoc.from_string(mdoc_str)
176+
df = mdoc.to_dataframe()
177+
178+
# Extra fields should be columns in dataframe
179+
assert 'CountsPerElectron' in df.columns
180+
assert 'UnknownCustomField' in df.columns
181+
assert df['CountsPerElectron'].iloc[0] == '42.0'
182+
assert df['UnknownCustomField'].iloc[0] == 'some_value'
183+
184+
# Round-trip through dataframe preserves extra fields
185+
mdoc2 = Mdoc.from_dataframe(df)
186+
assert mdoc2.section_data[0].model_extra['CountsPerElectron'] == '42.0'
187+
assert mdoc2.section_data[0].model_extra['UnknownCustomField'] == 'some_value'

tests/test_functions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,30 @@ def test_read_tilt_series_mdoc_string(tilt_series_mdoc_string):
2222
def test_read_montage_section_mdoc(montage_section_mdoc_file):
2323
df = read(montage_section_mdoc_file)
2424
assert isinstance(df, pd.DataFrame)
25-
assert df.shape == (63, 37)
25+
assert df.shape[0] == 63 # row count
26+
assert df.shape[1] >= 37 # at least this many columns (extra fields preserved)
27+
assert 'TiltAngle' in df.columns
2628

2729

2830
def test_read_montage_section_multiple_mdoc(montage_section_multiple_mdoc_file):
2931
df = read(montage_section_multiple_mdoc_file)
3032
assert isinstance(df, pd.DataFrame)
31-
assert df.shape == (100, 36)
33+
assert df.shape[0] == 100 # row count
34+
assert df.shape[1] >= 36 # at least this many columns (extra fields preserved)
3235

3336

3437
def test_read_frame_set_single_mdoc(frame_set_single_mdoc_file):
3538
df = read(frame_set_single_mdoc_file)
3639
assert isinstance(df, pd.DataFrame)
37-
assert df.shape == (1, 26)
40+
assert df.shape[0] == 1 # row count
41+
assert df.shape[1] >= 26 # at least this many columns (extra fields preserved)
3842

3943

4044
def test_read_frame_set_multiple_mdoc(frame_set_multiple_mdoc_file):
4145
df = read(frame_set_multiple_mdoc_file)
4246
assert isinstance(df, pd.DataFrame)
43-
assert df.shape == (21, 28)
47+
assert df.shape[0] == 21 # row count
48+
assert df.shape[1] >= 28 # at least this many columns (extra fields preserved)
4449

4550

4651
def test_write_tilt_series_mdoc(tilt_series_mdoc_file):

0 commit comments

Comments
 (0)