Coverage for python / lsst / dax / apdb / tests / data_factory.py: 17%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:49 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import datetime 

25import random 

26from collections.abc import Iterator 

27from typing import Any 

28 

29import astropy.time 

30import numpy 

31import pandas 

32 

33from lsst.sphgeom import LonLat, Region, UnitVector3d 

34 

35 

def _genPointsInRegion(region: Region, count: int) -> Iterator[LonLat]:
    """Generate a bunch of random points inside the given region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of points to generate.

    Yields
    ------
    point : `lsst.sphgeom.LonLat`
        Point for which ``region.contains`` returned `True`.

    Notes
    -----
    Returned points are random but not necessarily uniformly distributed.

    Candidates are drawn from the bounding box of the region and discarded
    when they fall outside of the region itself. The loop only counts accepted
    candidates, so it keeps running until ``count`` points have been yielded.
    """
    bbox = region.getBoundingBox()
    center = bbox.getCenter()
    center_lon = center.getLon().asRadians()
    center_lat = center.getLat().asRadians()
    width = bbox.getWidth().asRadians()
    height = bbox.getHeight().asRadians()
    while count > 0:
        # Longitude is drawn before latitude; keep this order so that a seeded
        # ``random`` produces the same sequence of points.
        lon = random.uniform(center_lon - width / 2, center_lon + width / 2)
        lat = random.uniform(center_lat - height / 2, center_lat + height / 2)
        lonlat = LonLat.fromRadians(lon, lat)
        uv3d = UnitVector3d(lonlat)
        if region.contains(uv3d):
            yield lonlat
            count -= 1

64 

65 

def makeObjectCatalog(
    region: Region | LonLat, count: int, visit_time: astropy.time.Time, *, start_id: int = 1, **kwargs: Any
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region` or `lsst.sphgeom.LonLat`
        Spherical region or spherical coordinate. For a region the points are
        generated randomly inside of it, for a coordinate every record is
        placed at that same coordinate.
    count : `int`
        Number of records to generate.
    visit_time : `astropy.time.Time`
        Time of the visit. Not used by this function.
    start_id : `int`
        Starting diaObjectId.
    **kwargs : `Any`
        Additional columns and their values to add to catalog.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Returned catalog contains four generated columns - ``diaObjectId``,
    ``ra`` and ``dec`` (in degrees), and ``nDiaSources`` (always 1) - in
    addition to the columns passed via ``kwargs``. If ``kwargs`` has a key
    with the same name as a generated column, the generated values are used.
    """
    if isinstance(region, Region):
        points = list(_genPointsInRegion(region, count))
    else:
        points = [region] * count
    # diaObjectId=0 may be used in some code for DiaSource foreign key to mean
    # the same as ``None``.
    ids = numpy.arange(start_id, len(points) + start_id, dtype=numpy.int64)
    ras = numpy.array([lonlat.getLon().asDegrees() for lonlat in points], dtype=numpy.float64)
    decs = numpy.array([lonlat.getLat().asDegrees() for lonlat in points], dtype=numpy.float64)
    nDiaSources = numpy.ones(len(points), dtype=numpy.int32)
    # Keyword arguments of ``dict`` override entries of the mapping, so the
    # generated columns take precedence over same-named ``kwargs``.
    data = dict(
        kwargs,
        diaObjectId=ids,
        ra=ras,
        dec=decs,
        nDiaSources=nDiaSources,
    )
    df = pandas.DataFrame(data)
    return df

113 

114 

115def makeTimestampNow(use_mjd: bool, offset_ms: int = 0) -> float | datetime.datetime: 

116 """Return current timestamp in yeither MJD TAI or datetime format. 

117 

118 Parameters 

119 ---------- 

120 use_mjd : `bool` 

121 If True return time as MJD TAI, otherwise as datetime. 

122 offset_ms : `int`, optional 

123 Additional offset in milliseconds to add to current timestamp. 

124 

125 Returns 

126 ------- 

127 timestamp : `float` or `datetime.datetime` 

128 Resulting timestamp. 

129 """ 

130 if use_mjd: 

131 ts = astropy.time.Time.now().tai.mjd 

132 if offset_ms != 0: 

133 ts += offset_ms / (24 * 3600 * 1_000) 

134 return ts 

135 else: 

136 # TODO: Note that for now we use naive datetime for time_processed, to 

137 # have it consistent with ap_association, this is being replaces with 

138 # MJD TAI in the new APDB schema. 

139 dt = datetime.datetime.now() 

140 if offset_ms != 0: 

141 dt += datetime.timedelta(milliseconds=offset_ms) 

142 return dt 

143 

144 

def _makeTimestampColumn(column: str, use_mjd: bool = True) -> str:
    """Return column name before/after schema migration to MJD TAI.

    Parameters
    ----------
    column : `str`
        Name of the timestamp column before the migration, e.g.
        ``time_processed``.
    use_mjd : `bool`, optional
        If True return the name of the MJD TAI column, otherwise return
        ``column`` unchanged.

    Returns
    -------
    name : `str`
        Resulting column name.
    """
    if not use_mjd:
        return column
    # Columns whose MJD TAI name is not simply the old name with a suffix.
    renamed = {
        "time_processed": "timeProcessedMjdTai",
        "time_withdrawn": "timeWithdrawnMjdTai",
    }
    return renamed.get(column, f"{column}MjdTai")

156 

157 

def makeSourceCatalog(
    objects: pandas.DataFrame,
    visit_time: astropy.time.Time,
    start_id: int = 0,
    visit: int = 1,
    detector: int = 1,
    use_mjd: bool = True,
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records. The ``diaObjectId``, ``ra`` and ``dec``
        columns are copied to the returned catalog.
    visit_time : `astropy.time.Time`
        Time of the visit.
    start_id : `int`
        Starting value for ``diaSourceId``.
    visit, detector : `int`
        Value for ``visit`` and ``detector`` fields.
    use_mjd : `bool`
        If True use MJD TAI for timestamp columns.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records, one record per input DiaObject.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests.

    ``midpointMjdTai`` is filled with ``visit_time.mjd`` as is, no time scale
    conversion is applied here.

    When there is more than one record, ``centroid_flag`` of the last record
    is null, all other records have it set to `True`.
    """
    nrows = len(objects)
    midpointMjdTai = visit_time.mjd
    centroid_flag: list[bool | None] = [True] * nrows
    if nrows > 1:
        centroid_flag[-1] = None
    df = pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            "diaObjectId": objects["diaObjectId"],
            "visit": numpy.full(nrows, visit, dtype=numpy.int64),
            "detector": numpy.full(nrows, detector, dtype=numpy.int16),
            "parentDiaSourceId": 0,
            "ra": objects["ra"],
            "dec": objects["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            # Series columns are aligned on their index labels, so the flags
            # have to carry the same index as the columns taken from
            # ``objects`` to stay matched with them row by row.
            "centroid_flag": pandas.Series(centroid_flag, dtype="boolean", index=objects.index),
            "ssObjectId": pandas.NA,
            _makeTimestampColumn("time_processed", use_mjd): makeTimestampNow(use_mjd),
        }
    )
    return df

212 

213 

def makeForcedSourceCatalog(
    objects: pandas.DataFrame,
    visit_time: astropy.time.Time,
    visit: int = 1,
    detector: int = 1,
    use_mjd: bool = True,
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaForcedSources associated with
    the input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records. The ``diaObjectId``, ``ra`` and ``dec``
        columns are copied to the returned catalog.
    visit_time : `astropy.time.Time`
        Time of the visit.
    visit, detector : `int`
        Value for ``visit`` and ``detector`` fields.
    use_mjd : `bool`
        If True use MJD TAI for timestamp columns.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records, one record per input DiaObject.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests.

    ``midpointMjdTai`` is filled with ``visit_time.mjd`` as is, no time scale
    conversion is applied here. ``flags`` is zero for every record.
    """
    nrows = len(objects)
    midpointMjdTai = visit_time.mjd
    df = pandas.DataFrame(
        {
            "diaObjectId": objects["diaObjectId"],
            "visit": numpy.full(nrows, visit, dtype=numpy.int64),
            "detector": numpy.full(nrows, detector, dtype=numpy.int16),
            "ra": objects["ra"],
            "dec": objects["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            "flags": numpy.full(nrows, 0, dtype=numpy.int64),
            _makeTimestampColumn("time_processed", use_mjd): makeTimestampNow(use_mjd),
        }
    )
    return df