Coverage for python/lsst/dax/apdb/tests/data_factory.py: 23%

48 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-04 02:33 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import numpy 

25import pandas 

26import random 

27from typing import Iterator 

28 

29from lsst.daf.base import DateTime 

30from lsst.sphgeom import LonLat, Region, UnitVector3d 

31from lsst.geom import SpherePoint 

32 

33 

def _genPointsInRegion(region: Region, count: int) -> Iterator[SpherePoint]:
    """Yield random SpherePoints that lie inside a spherical region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region to sample.
    count : `int`
        Number of points to produce.

    Yields
    ------
    point : `lsst.geom.SpherePoint`
        A point for which ``region.contains`` is true.

    Notes
    -----
    Candidates are drawn with `random.uniform` in longitude and latitude
    within the region's bounding box, and rejected if they fall outside the
    region. The returned points are therefore random but not necessarily
    uniformly distributed on the sphere. The loop keeps drawing until
    ``count`` accepted points have been yielded, so it never finishes if no
    candidate ever lands inside ``region``.
    """
    box = region.getBoundingBox()
    mid = box.getCenter()
    lon0 = mid.getLon().asRadians()
    lat0 = mid.getLat().asRadians()
    half_width = box.getWidth().asRadians() / 2
    half_height = box.getHeight().asRadians() / 2

    remaining = count
    while remaining > 0:
        # Longitude is drawn before latitude, matching the RNG call order.
        candidate = LonLat.fromRadians(
            random.uniform(lon0 - half_width, lon0 + half_width),
            random.uniform(lat0 - half_height, lat0 + half_height),
        )
        if not region.contains(UnitVector3d(candidate)):
            # Outside the region: reject and draw again.
            continue
        yield SpherePoint(candidate)
        remaining -= 1

62 

63 

def makeObjectCatalog(
    region: Region, count: int, visit_time: DateTime, *, start_id: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of records to generate.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit, stored in the ``lastNonForcedSource`` column.
    start_id : `int`, optional
        Starting diaObjectId; records get consecutive IDs from this value.
        The default is 1 because diaObjectId=0 may be used in some code for
        a DiaSource foreign key to mean the same as ``None``.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Returned catalog only contains five columns:

    - ``diaObjectId``: consecutive int64 IDs starting at ``start_id``;
    - ``ra`` and ``decl``: positions in degrees (float64);
    - ``nDiaSources``: always 1 (int32);
    - ``lastNonForcedSource``: ``visit_time`` converted with
      ``DateTime.toPython()``, same value for every row.
    """
    points = list(_genPointsInRegion(region, count))
    nrows = len(points)
    # IDs start at start_id (default 1, see docstring) rather than 0.
    ids = numpy.arange(start_id, start_id + nrows, dtype=numpy.int64)
    ras = numpy.array([sp.getRa().asDegrees() for sp in points], dtype=numpy.float64)
    decls = numpy.array([sp.getDec().asDegrees() for sp in points], dtype=numpy.float64)
    nDiaSources = numpy.ones(nrows, dtype=numpy.int32)
    # Scalar value; pandas broadcasts it to every row.
    dt = visit_time.toPython()
    df = pandas.DataFrame(
        {
            "diaObjectId": ids,
            "ra": ras,
            "decl": decls,
            "nDiaSources": nDiaSources,
            "lastNonForcedSource": dt,
        }
    )
    return df

108 

109 

def makeSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, start_id: int = 0, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; must have ``diaObjectId``, ``ra`` and
        ``decl`` columns. One DiaSource is made per row.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit.
    start_id : `int`
        Starting value for ``diaSourceId``; sources get consecutive IDs
        from this value. (The ``diaObjectId`` values are copied from
        ``objects``.)
    ccdVisitId : `int`
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests:
    ``diaSourceId``, ``diaObjectId``, ``ccdVisitId``, ``parentDiaSourceId``
    (always 0), ``ra``, ``decl``, ``midPointTai`` (``visit_time`` as MJD,
    from ``DateTime.get(system=DateTime.MJD)``) and ``flags`` (always 0).
    Because ``diaObjectId``, ``ra`` and ``decl`` are taken as Series from
    ``objects``, the result shares the index of ``objects``.
    """
    nrows = len(objects)
    midPointTai = visit_time.get(system=DateTime.MJD)
    df = pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            "diaObjectId": objects["diaObjectId"],
            "ccdVisitId": numpy.full(nrows, ccdVisitId, dtype=numpy.int64),
            # Scalar; pandas broadcasts it to every row.
            "parentDiaSourceId": 0,
            # Each source sits exactly at the position of its object.
            "ra": objects["ra"],
            "decl": objects["decl"],
            "midPointTai": numpy.full(nrows, midPointTai, dtype=numpy.float64),
            "flags": numpy.full(nrows, 0, dtype=numpy.int64),
        }
    )
    return df

151 

152 

def makeForcedSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Build DiaForcedSource records, one per input DiaObject.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; its ``diaObjectId`` column is copied
        into the result.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit.
    ccdVisitId : `int`
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records with columns ``diaObjectId``,
        ``ccdVisitId``, ``midPointTai`` (``visit_time`` as MJD, from
        ``DateTime.get(system=DateTime.MJD)``) and ``flags`` (always 0).

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests.
    """
    n_objects = len(objects)
    mjd = visit_time.get(system=DateTime.MJD)
    columns = {
        "diaObjectId": objects["diaObjectId"],
        "ccdVisitId": numpy.full(n_objects, ccdVisitId, dtype=numpy.int64),
        "midPointTai": numpy.full(n_objects, mjd, dtype=numpy.float64),
        "flags": numpy.zeros(n_objects, dtype=numpy.int64),
    }
    return pandas.DataFrame(columns)

188 

189 

def makeSSObjectCatalog(count: int, start_id: int = 1, flags: int = 0) -> pandas.DataFrame:
    """Make a catalog containing a bunch of SSObjects.

    Parameters
    ----------
    count : `int`
        Number of records to generate.
    start_id : `int`
        Initial SSObject ID; records get consecutive IDs from this value.
    flags : `int`
        Value for ``flags`` column, same for every row.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of SSObjects records.

    Notes
    -----
    Returned catalog only contains three columns - ``ssObjectId`` (int64),
    ``arc`` (float32, always 0.001), and ``flags`` (int64).
    """
    ids = numpy.arange(start_id, start_id + count, dtype=numpy.int64)
    # Fixed placeholder value; only the column's presence matters for tests.
    arc = numpy.full(count, 0.001, dtype=numpy.float32)
    flags_array = numpy.full(count, flags, dtype=numpy.int64)
    df = pandas.DataFrame({"ssObjectId": ids, "arc": arc, "flags": flags_array})
    return df