Coverage for python / lsst / dax / apdb / tests / data_factory.py: 16%
71 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 10:35 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 10:35 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24import datetime
25import random
26from collections.abc import Iterator
27from typing import Any
29import astropy.time
30import numpy
31import pandas
33from lsst.sphgeom import LonLat, Region, UnitVector3d
def _genPointsInRegion(region: Region, count: int) -> Iterator[LonLat]:
    """Yield random `lsst.sphgeom.LonLat` points lying inside a region.

    Candidate points are drawn uniformly in longitude/latitude inside the
    bounding box of ``region``; candidates that fall outside ``region``
    itself are rejected and drawing continues until ``count`` points have
    been yielded.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region the points must lie in.
    count : `int`
        Number of points to yield; nothing is yielded if it is not positive.

    Yields
    ------
    point : `lsst.sphgeom.LonLat`
        Point contained in ``region``.

    Notes
    -----
    Points are random but not necessarily uniformly distributed on the
    sphere. The loop only ends once enough candidates are accepted, so it
    never terminates for a region that contains none of the candidates.
    """
    box = region.getBoundingBox()
    box_center = box.getCenter()
    lon0 = box_center.getLon().asRadians()
    lat0 = box_center.getLat().asRadians()
    half_width = box.getWidth().asRadians() / 2
    half_height = box.getHeight().asRadians() / 2

    remaining = count
    while remaining > 0:
        # Longitude is drawn before latitude; keep this order so the
        # sequence of random numbers consumed stays the same.
        candidate = LonLat.fromRadians(
            random.uniform(lon0 - half_width, lon0 + half_width),
            random.uniform(lat0 - half_height, lat0 + half_height),
        )
        if not region.contains(UnitVector3d(candidate)):
            continue
        yield candidate
        remaining -= 1
def makeObjectCatalog(
    region: Region | LonLat, count: int, *, start_id: int = 1, **kwargs: Any
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region` or `lsst.sphgeom.LonLat`
        Spherical region or spherical coordinate. For a region, random
        positions inside it are generated; for a coordinate, every record
        gets that same position.
    count : `int`
        Number of records to generate.
    start_id : `int`
        Starting diaObjectId; records get consecutive IDs from this value.
    **kwargs : `Any`
        Additional columns and their values to add to catalog. If a keyword
        has the same name as one of the generated columns, the generated
        value wins.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Besides the columns passed in ``kwargs`` (which come first), the
    returned catalog contains five columns: ``diaObjectId``, ``ra`` and
    ``dec`` (in degrees), ``nDiaSources`` (always 1), and
    ``firstDiaSourceMjdTai`` (always 60000.0).
    """
    if isinstance(region, Region):
        points = list(_genPointsInRegion(region, count))
    else:
        points = [region] * count
    nrows = len(points)
    # diaObjectId=0 may be used in some code for DiaSource foreign key to mean
    # the same as ``None``, which is why the default start_id is 1.
    ids = numpy.arange(start_id, nrows + start_id, dtype=numpy.int64)
    ras = numpy.array([lonlat.getLon().asDegrees() for lonlat in points], dtype=numpy.float64)
    decs = numpy.array([lonlat.getLat().asDegrees() for lonlat in points], dtype=numpy.float64)
    nDiaSources = numpy.ones(nrows, dtype=numpy.int32)
    firstDiaSourceMjdTai = numpy.full(nrows, 60000.0, dtype=numpy.float64)
    # Keyword arguments given to dict() after the mapping override it, so the
    # generated columns replace any same-named entries from ``kwargs``.
    data = dict(
        kwargs,
        diaObjectId=ids,
        ra=ras,
        dec=decs,
        nDiaSources=nDiaSources,
        firstDiaSourceMjdTai=firstDiaSourceMjdTai,
    )
    return pandas.DataFrame(data)
def makeTimestamp(time: astropy.time.Time, use_mjd: bool, offset_ms: int = 0) -> float | datetime.datetime:
    """Convert a time into a timestamp, either as MJD TAI or as a datetime.

    Parameters
    ----------
    time : `astropy.time.Time`
        Time value to convert.
    use_mjd : `bool`
        If `True` return the MJD in the TAI scale, otherwise return the
        value of ``time.datetime``.
    offset_ms : `int`, optional
        Offset in milliseconds added to ``time`` before it is returned.

    Returns
    -------
    timestamp : `float` or `datetime.datetime`
        Converted timestamp.
    """
    if not use_mjd:
        # TODO: a naive datetime is used for time_processed for now, to stay
        # consistent with ap_association; the new APDB schema replaces it
        # with MJD TAI.
        # NOTE(review): unlike the MJD branch, no explicit scale conversion
        # is done here, so the datetime is in the scale of ``time`` — confirm
        # callers pass UTC times.
        result_dt = time.datetime
        if offset_ms != 0:
            result_dt = result_dt + datetime.timedelta(milliseconds=offset_ms)
        return result_dt

    mjd_tai = time.tai.mjd
    if offset_ms != 0:
        ms_per_day = 24 * 3600 * 1_000
        mjd_tai = mjd_tai + offset_ms / ms_per_day
    return mjd_tai
def makeTimestampColumn(column: str, use_mjd: bool = True) -> str:
    """Return the name of a timestamp column before or after the schema
    migration to MJD TAI.

    Parameters
    ----------
    column : `str`
        Column name in the pre-migration schema.
    use_mjd : `bool`, optional
        If `True` return the post-migration (MJD TAI) name, otherwise return
        ``column`` unchanged.

    Returns
    -------
    name : `str`
        Column name to use. ``time_processed`` and ``time_withdrawn`` map to
        camel-case names; any other column just gets an ``MjdTai`` suffix.
    """
    if not use_mjd:
        return column
    renamed = {
        "time_processed": "timeProcessedMjdTai",
        "time_withdrawn": "timeWithdrawnMjdTai",
    }
    return renamed.get(column, f"{column}MjdTai")
def makeSourceCatalog(
    objects: pandas.DataFrame,
    visit_time: astropy.time.Time,
    start_id: int = 0,
    visit: int = 1,
    detector: int = 1,
    *,
    use_mjd: bool = True,
    processing_time: astropy.time.Time | None = None,
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; its ``diaObjectId``, ``ra`` and
        ``dec`` columns are copied row by row. Its index is ignored.
    visit_time : `astropy.time.Time`
        Time of the visit.
    start_id : `int`
        Starting value for ``diaSourceId``; rows get consecutive IDs.
    visit, detector : `int`
        Value for ``visit`` and ``detector`` fields.
    use_mjd : `bool`
        If True use MJD TAI for timestamp columns.
    processing_time : `astropy.time.Time` or `None`
        Processing time, if `None` the value of ``visit_time`` is used.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records, one per row of ``objects``, with a
        default integer index.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests.
    ``diaObjectId`` uses the nullable ``Int64`` dtype. ``centroid_flag`` is
    `True` for every row except the last one, which is null when there is
    more than one row.
    """
    if processing_time is None:
        processing_time = visit_time
    # Columns copied from ``objects`` are Series that align on the index,
    # while the other columns are positional. Drop the original index so a
    # filtered/concatenated/id-indexed ``objects`` does not produce a
    # length mismatch or misaligned rows.
    objs = objects.reset_index(drop=True)
    nrows = len(objs)
    midpointMjdTai = visit_time.tai.mjd
    # Presumably the null in the last row exercises nullable boolean
    # handling in the code under test.
    centroid_flag: list[bool | None] = [True] * nrows
    if nrows > 1:
        centroid_flag[-1] = None
    df = pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            "diaObjectId": pandas.Series(objs["diaObjectId"], dtype="Int64"),
            "visit": numpy.full(nrows, visit, dtype=numpy.int64),
            "detector": numpy.full(nrows, detector, dtype=numpy.int16),
            "parentDiaSourceId": 0,
            "ra": objs["ra"],
            "dec": objs["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            "centroid_flag": pandas.Series(centroid_flag, dtype="boolean"),
            "ssObjectId": pandas.NA,
            makeTimestampColumn("time_processed", use_mjd): makeTimestamp(processing_time, use_mjd),
        }
    )
    return df
def makeForcedSourceCatalog(
    objects: pandas.DataFrame,
    visit_time: astropy.time.Time,
    visit: int = 1,
    detector: int = 1,
    *,
    use_mjd: bool = True,
    processing_time: astropy.time.Time | None = None,
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaForcedSources associated with
    the input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; its ``diaObjectId``, ``ra`` and
        ``dec`` columns are copied.
    visit_time : `astropy.time.Time`
        Time of the visit; stored as MJD in the TAI scale.
    visit, detector : `int`
        Value for ``visit`` and ``detector`` fields.
    use_mjd : `bool`
        If True use MJD TAI for timestamp columns.
    processing_time : `astropy.time.Time` or `None`
        Processing time, if `None` the value of ``visit_time`` is used.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records, one per row of ``objects``.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests.
    ``flags`` is always 0.
    """
    if processing_time is None:
        processing_time = visit_time
    nrows = len(objects)
    # The column holds MJD TAI, so convert to the TAI scale explicitly
    # instead of using whatever scale ``visit_time`` happens to be in.
    midpointMjdTai = visit_time.tai.mjd
    df = pandas.DataFrame(
        {
            "diaObjectId": objects["diaObjectId"],
            "visit": numpy.full(nrows, visit, dtype=numpy.int64),
            "detector": numpy.full(nrows, detector, dtype=numpy.int16),
            "ra": objects["ra"],
            "dec": objects["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            "flags": numpy.full(nrows, 0, dtype=numpy.int64),
            makeTimestampColumn("time_processed", use_mjd): makeTimestamp(processing_time, use_mjd),
        }
    )
    return df