-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcryodataset.py
More file actions
81 lines (67 loc) · 3.03 KB
/
cryodataset.py
File metadata and controls
81 lines (67 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
### remake a prediction
import torch
import numpy as np
from torch.utils.data import Dataset
import h5py
class CryoData(Dataset):
    """Map-style dataset serving cryo-EM samples stored in an HDF5 file.

    Each top-level key of the file is one sample. The file is opened once in
    ``__init__`` to list the keys and is reopened on every ``__getitem__``
    call; no file handle is stored on the instance.
    """

    # Suffixes of the optional ``homolog_<suffix>`` datasets that are copied
    # into a sample when the group contains them.
    HOMOLOG_TYPES = ('perturbed1', 'perturbed2', 'complex')

    def __init__(self, h5_file, representation='backbone'):
        """
        Initialize dataset for cryo-EM data.

        Args:
            h5_file (str): Path to H5 file
            representation (str): 'backbone' for Cα atoms or 'allatom' for
                all atoms. The value is stored and echoed back in every
                sample; it is not validated here and does not change which
                datasets are read.
        """
        self.h5_file = h5_file
        self.representation = representation
        with h5py.File(self.h5_file, 'r') as file:
            self.keys = list(file.keys())

    def __len__(self):
        """Return the number of top-level groups found in the H5 file."""
        return len(self.keys)

    def __getitem__(self, idx):
        """
        Load the sample stored under the ``idx``-th key.

        Args:
            idx (int): Position in ``self.keys``.

        Returns:
            dict: See ``_build_sample`` for the layout.

        Raises:
            IndexError: If ``idx`` is outside ``self.keys``.
            KeyError: If a required dataset is missing from the group.
        """
        with h5py.File(self.h5_file, 'r') as file:
            key_name = self.keys[idx]  # e.g., 'EMDB_1874'
            # Everything is copied into tensors / plain values before the
            # file is closed, so the returned dict does not reference it.
            return self._build_sample(key_name, file[key_name])

    @staticmethod
    def _read_tensor(group, name):
        """Read the whole dataset ``group[name]`` into a new tensor."""
        # ``[()]`` reads the full dataset whatever its rank; ``[:]`` is
        # rejected by scalar (0-d) datasets.
        return torch.tensor(group[name][()])

    def _build_sample(self, key_name, group):
        """
        Assemble the sample dict for one group.

        Args:
            key_name (str): Name of the group, e.g. 'EMDB_1874'.
            group: Mapping of dataset names to array-likes that also exposes
                an ``attrs`` mapping (an h5py group).

        Returns:
            dict: Keys 'emdb_id', 'pdb_id', 'representation',
            'ground_truth_grid', 'ground_truth_coords', 'em_volume',
            'scale_norm', 'scale_min', 'metadata', followed by one
            'homolog_<type>' tensor for each optional dataset present.
        """
        # Extract EMDB ID from key name
        emdb_id = key_name.replace('EMDB_', '')

        # Get ground truth data
        ground_truth_grid = self._read_tensor(group, 'ground_truth_grid')
        ground_truth_coords = self._read_tensor(group, 'ground_truth_coords')
        em_volume = self._read_tensor(group, 'em_volume')

        # Get scale information
        scale_norm = self._read_tensor(group, 'scale_norm')
        scale_min = self._read_tensor(group, 'scale_min')

        # Get metadata from attributes; absent attributes fall back to
        # empty values instead of raising.
        attrs = group.attrs
        metadata = {
            'pdb_file': attrs.get('pdb_file', ''),
            'em_map_file': attrs.get('em_map_file', ''),
            'homolog_types': attrs.get('homolog_types', []),
        }

        # Initialize result with ground truth data
        result = {
            'emdb_id': emdb_id,
            'pdb_id': attrs.get('pdb_id', ''),
            'representation': self.representation,
            'ground_truth_grid': ground_truth_grid,  # 64³ binary grid
            'ground_truth_coords': ground_truth_coords,  # Raw coordinates
            'em_volume': em_volume,  # EM density map (64³)
            'scale_norm': scale_norm,
            'scale_min': scale_min,
            'metadata': metadata,
        }

        # Add available homologs as flat keys
        for homolog_type in self.HOMOLOG_TYPES:
            homolog_key = f'homolog_{homolog_type}'
            if homolog_key in group:
                result[homolog_key] = self._read_tensor(group, homolog_key)
        return result
class CryoDataBackbone(CryoData):
    """CryoData variant whose samples are labelled with the 'backbone' (Cα) representation."""

    def __init__(self, h5_file):
        """
        Args:
            h5_file (str): Path to H5 file, forwarded to the base class.
        """
        # The representation label is fixed by this subclass.
        super().__init__(h5_file, 'backbone')
class CryoDataAllAtom(CryoData):
    """CryoData variant whose samples are labelled with the 'allatom' representation."""

    def __init__(self, h5_file, representation='allatom'):
        """
        Args:
            h5_file (str): Path to H5 file, forwarded to the base class.
            representation (str): Accepted but not forwarded; the base class
                is always initialised with 'allatom'.
        """
        # NOTE(review): the `representation` argument is ignored here, unlike
        # in CryoData where it is stored - confirm this is intended.
        super().__init__(h5_file, 'allatom')