Coverage for python/lsst/dax/apdb/tests/data_factory.py: 24%

50 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-12 10:17 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import random 

25from collections.abc import Iterator 

26from typing import Any 

27 

28import numpy 

29import pandas 

30from lsst.daf.base import DateTime 

31from lsst.geom import SpherePoint 

32from lsst.sphgeom import LonLat, Region, UnitVector3d 

33 

34 

def _genPointsInRegion(region: Region, count: int) -> Iterator[SpherePoint]:
    """Yield random SpherePoints that lie inside a spherical region.

    Candidate positions are drawn uniformly in longitude and latitude from
    the region's bounding box; a candidate is kept only when ``region``
    contains it (rejection sampling).

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region to sample.
    count : `int`
        Number of points to yield. Nothing is yielded when it is not
        positive.

    Yields
    ------
    point : `lsst.geom.SpherePoint`
        A point located inside ``region``.

    Notes
    -----
    Returned points are random but not necessarily uniformly distributed.
    Sampling keeps retrying until ``count`` points are accepted, so a region
    that covers only a tiny fraction of its bounding box may take a long time.
    """
    box = region.getBoundingBox()
    mid = box.getCenter()
    mid_lon = mid.getLon().asRadians()
    mid_lat = mid.getLat().asRadians()
    half_width = box.getWidth().asRadians() / 2
    half_height = box.getHeight().asRadians() / 2

    # Sampling window: the bounding box expressed as lon/lat ranges.
    lon_min, lon_max = mid_lon - half_width, mid_lon + half_width
    lat_min, lat_max = mid_lat - half_height, mid_lat + half_height

    remaining = count
    while remaining > 0:
        candidate = LonLat.fromRadians(
            random.uniform(lon_min, lon_max),
            random.uniform(lat_min, lat_max),
        )
        if not region.contains(UnitVector3d(candidate)):
            # Outside the region (only inside its bounding box) - retry.
            continue
        yield SpherePoint(candidate)
        remaining -= 1

63 

64 

def makeObjectCatalog(
    region: Region, count: int, visit_time: DateTime, *, start_id: int = 1, **kwargs: Any
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region; all generated positions fall inside it.
    count : `int`
        Number of records to generate.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit.
    start_id : `int`, optional
        Starting diaObjectId; IDs are assigned consecutively from this value.
    **kwargs : `Any`
        Additional columns and their values to add to catalog. A keyword whose
        name matches one of the standard columns listed in Notes is ignored;
        the standard column value takes precedence.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Besides the columns supplied through ``kwargs`` (which come first), the
    returned catalog contains:

    - ``diaObjectId`` (int64), consecutive values starting at ``start_id``;
    - ``ra`` and ``dec`` (float64, in degrees);
    - ``nDiaSources`` (int32), set to 1 for every record;
    - ``lastNonForcedSource``, set to ``visit_time.toPython()`` for every
      record.
    """
    points = list(_genPointsInRegion(region, count))
    # diaObjectId=0 may be used in some code for DiaSource foreign key to mean
    # the same as ``None``, which is why ``start_id`` defaults to 1.
    ids = numpy.arange(start_id, len(points) + start_id, dtype=numpy.int64)
    ras = numpy.array([sp.getRa().asDegrees() for sp in points], dtype=numpy.float64)
    decs = numpy.array([sp.getDec().asDegrees() for sp in points], dtype=numpy.float64)
    nDiaSources = numpy.ones(len(points), dtype=numpy.int32)
    dt = visit_time.toPython()
    # Keyword arguments override entries of the ``kwargs`` mapping, so the
    # standard columns below win over same-named user columns.
    data = dict(
        kwargs,
        diaObjectId=ids,
        ra=ras,
        dec=decs,
        nDiaSources=nDiaSources,
        lastNonForcedSource=dt,
    )
    df = pandas.DataFrame(data)
    return df

111 

112 

def makeSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, start_id: int = 0, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    One DiaSource is produced per DiaObject row in ``objects``.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; must have ``diaObjectId``, ``ra`` and
        ``dec`` columns.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit.
    start_id : `int`, optional
        Starting value for ``diaSourceId``; IDs are assigned consecutively
        from this value.
    ccdVisitId : `int`, optional
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests:
    ``diaSourceId``, ``diaObjectId``, ``ra`` and ``dec`` (copied from
    ``objects``), ``ccdVisitId``, ``parentDiaSourceId`` (always 0),
    ``midpointMjdTai`` (MJD of ``visit_time``), and ``flags`` (always 0).
    Because some columns are copied from ``objects`` as Series, the result
    shares the index of ``objects``.
    """
    nrows = len(objects)
    midpointMjdTai = visit_time.get(system=DateTime.MJD)
    df = pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            "diaObjectId": objects["diaObjectId"],
            "ccdVisitId": numpy.full(nrows, ccdVisitId, dtype=numpy.int64),
            "parentDiaSourceId": 0,
            "ra": objects["ra"],
            "dec": objects["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            "flags": numpy.full(nrows, 0, dtype=numpy.int64),
        }
    )
    return df

154 

155 

def makeForcedSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaForcedSources associated with
    the input DiaObjects.

    One forced source is produced per DiaObject row in ``objects``.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; must have a ``diaObjectId`` column.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit; its MJD value fills ``midpointMjdTai``.
    ccdVisitId : `int`, optional
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests:
    ``diaObjectId``, ``ccdVisitId``, ``midpointMjdTai`` and ``flags``
    (always 0).
    """
    size = len(objects)
    mjd = visit_time.get(system=DateTime.MJD)

    # Insertion order defines the column order of the resulting catalog.
    columns: dict[str, Any] = {}
    columns["diaObjectId"] = objects["diaObjectId"]
    columns["ccdVisitId"] = numpy.full(size, ccdVisitId, dtype=numpy.int64)
    columns["midpointMjdTai"] = numpy.full(size, mjd, dtype=numpy.float64)
    columns["flags"] = numpy.zeros(size, dtype=numpy.int64)
    return pandas.DataFrame(columns)

191 

192 

def makeSSObjectCatalog(count: int, start_id: int = 1, flags: int = 0) -> pandas.DataFrame:
    """Make a catalog containing a bunch of SSObjects.

    Parameters
    ----------
    count : `int`
        Number of records to generate.
    start_id : `int`, optional
        Initial SSObject ID; IDs are assigned consecutively from this value.
    flags : `int`, optional
        Value for ``flags`` column, identical for every record.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of SSObjects records.

    Notes
    -----
    Returned catalog only contains three columns - ``ssObjectId`` (int64),
    ``arc`` (float32, always 0.001), and ``flags`` (int64).
    """
    ids = numpy.arange(start_id, count + start_id, dtype=numpy.int64)
    arc = numpy.full(count, 0.001, dtype=numpy.float32)
    flags_array = numpy.full(count, flags, dtype=numpy.int64)
    df = pandas.DataFrame({"ssObjectId": ids, "arc": arc, "flags": flags_array})
    return df