Skip to content

Commit dc09be8

Browse files
authored
Add Result Unraveling Function (#1047)
* initial * black formatting * address comments on pr #1047
1 parent f537e65 commit dc09be8

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed

src/kbmod/util_functions.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from astropy.io import fits
77
from astropy.time import Time
88
from itertools import product
9+
import pandas as pd
910

1011
from kbmod.core.image_stack_py import LayeredImagePy
1112

@@ -114,3 +115,121 @@ def load_deccam_layered_image(filename, psf):
114115
)
115116

116117
return img
118+
119+
120+
def get_unique_obstimes(all_obstimes):
121+
"""Get the unique observation times and their indices.
122+
Used to group observations for mosaicking.
123+
124+
Parameters
125+
----------
126+
all_obstimes : `np.ndarray`
127+
The array of observation times.
128+
129+
Returns
130+
-------
131+
unique_obstimes : `np.ndarray`
132+
The unique observation times.
133+
unique_indices : `list`
134+
A list of lists, where each sublist contains the indices of the grouping.
135+
"""
136+
unique_obstimes = np.unique(all_obstimes)
137+
unique_indices = [list(np.where(all_obstimes == time)[0]) for time in unique_obstimes]
138+
return unique_obstimes, unique_indices
139+
140+
141+
def get_magnitude(flux, zero_point):
142+
"""Convert a flux value to a magnitude using the zero point.
143+
144+
Parameters
145+
----------
146+
flux : `float`
147+
The flux value to convert.
148+
zero_point : `float`
149+
The zero point of the observations.
150+
151+
Returns
152+
-------
153+
mag : `float`
154+
The calculated magnitude.
155+
"""
156+
mag = -2.5 * np.log10(flux) + zero_point
157+
return mag
158+
159+
160+
def unravel_results(results, image_collection, obscode="X05", batch_id=None):
161+
"""Take a results file and transform it into a table of individual observations.
162+
163+
Parameters
164+
----------
165+
results : `kbmod.results.Results`
166+
The results.
167+
image_collection : `kbmod.image_collection.ImageCollection`
168+
The image collection containing the images used in the results.
169+
obscode : `str`, optional
170+
The observatory code to use for the observations.
171+
Default: "X05" (LSST).
172+
batch_id : `str`, optional
173+
The batch ID to use for this result set.
174+
individual observation ids will be in the format of
175+
"{batch_id}-{result #}-{observation #}".
176+
177+
Returns
178+
-------
179+
final_df : `pandas.DataFrame`
180+
A DataFrame containing the individual observations with columns:
181+
- id: The unique identifier for the observation.
182+
- ra: The right ascension of the observation in degrees.
183+
- dec: The declination of the observation in degrees.
184+
- magnitude: The magnitude of the observation.
185+
- mjd: The modified Julian date of the observation.
186+
- band: The band of the observation.
187+
- obscode: The observatory code for the observation.
188+
"""
189+
zp = np.mean(image_collection["zeroPoint"])
190+
191+
ids = []
192+
ras = []
193+
decs = []
194+
mags = []
195+
mjds = []
196+
bands = []
197+
obscodes = []
198+
199+
all_times = results.table.meta["mjd_mid"]
200+
all_bands = image_collection["band"]
201+
202+
_, unique_indices = get_unique_obstimes(image_collection["mjd_mid"])
203+
first_of_each_frame = np.array([i[0] for i in unique_indices])
204+
205+
for i, row in enumerate(results):
206+
if "obs_valid" in results.table.colnames:
207+
valid_obs = row["obs_valid"]
208+
else:
209+
valid_obs = np.full(row["obs_count"], True)
210+
num_valid = row["obs_count"]
211+
212+
# need to figure out a better way to do this
213+
if batch_id is not None:
214+
ids.append([f"{batch_id}-{i}-{j}" for j in range(num_valid)])
215+
else:
216+
ids.append([f"{i}-{j}" for j in range(num_valid)])
217+
218+
ras.append(row["img_ra"][valid_obs])
219+
decs.append(row["img_dec"][valid_obs])
220+
221+
mags.append([get_magnitude(row["flux"], zp)] * num_valid)
222+
mjds.append(all_times[valid_obs])
223+
bands.append(all_bands[first_of_each_frame][valid_obs])
224+
obscodes.append([obscode] * num_valid)
225+
226+
final_df = pd.DataFrame()
227+
final_df["id"] = np.concatenate(ids)
228+
final_df["ra"] = np.concatenate(ras)
229+
final_df["dec"] = np.concatenate(decs)
230+
final_df["magnitude"] = np.concatenate(mags)
231+
final_df["mjd"] = np.concatenate(mjds)
232+
final_df["band"] = np.concatenate(bands)
233+
final_df["obscode"] = np.concatenate(obscodes)
234+
235+
return final_df

tests/test_util_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
from astropy.table import Table
12
import numpy as np
23
import unittest
34

45
from pathlib import Path
56

67
from kbmod.core.image_stack_py import LayeredImagePy
78
from kbmod.core.psf import PSF
9+
from kbmod.results import Results
10+
from kbmod.search import Trajectory
811
from kbmod.util_functions import (
912
get_matched_obstimes,
1013
load_deccam_layered_image,
1114
mjd_to_day,
15+
unravel_results,
1216
)
1317

1418

@@ -34,6 +38,48 @@ def test_load_deccam_layered_image(self):
3438
self.assertGreater(img.width, 0)
3539
self.assertGreater(img.height, 0)
3640

41+
def test_unravel_results(self):
42+
trj_list = []
43+
num_images = 10
44+
num_trjs = 11
45+
46+
for i in range(num_trjs):
47+
trj = Trajectory(
48+
x=i,
49+
y=i + 0,
50+
vx=i - 2.0,
51+
vy=i + 5.0,
52+
flux=5.0 * i,
53+
lh=100.0 + i,
54+
obs_count=num_images,
55+
)
56+
trj_list.append(trj)
57+
58+
res = Results.from_trajectories(trj_list)
59+
60+
# create an "image collection" with all required fields
61+
ic = Table()
62+
ic["mjd_mid"] = [60000 + i for i in range(num_images)]
63+
ic["band"] = ["g" for _ in range(num_images)]
64+
ic["zeroPoint"] = [31.4 for _ in range(num_images)]
65+
66+
res.table.meta["mjd_mid"] = ic["mjd_mid"]
67+
obs_count = [num_images for _ in range(num_trjs)]
68+
res.table["obs_count"] = obs_count
69+
res.table["img_ra"] = [np.array([j + (i * 0.1) for j in range(num_images)]) for i in range(num_trjs)]
70+
res.table["img_dec"] = [np.array([j + (i * 0.1) for j in range(num_images)]) for i in range(num_trjs)]
71+
72+
df = unravel_results(res, ic)
73+
self.assertEqual(len(df), (num_images * num_trjs))
74+
75+
obs_count[int(num_images / 2)] = num_images - 1
76+
valids = [[True] * num_images for _ in range(num_trjs)]
77+
valids[int(num_images / 2)][-1] = False # make one observation invalid
78+
res.table["obs_valid"] = valids
79+
res.table["obs_count"] = obs_count
80+
df2 = unravel_results(res, ic)
81+
self.assertEqual(len(df2), (num_images * num_trjs) - 1)
82+
3783

3884
if __name__ == "__main__":
3985
unittest.main()

0 commit comments

Comments
 (0)