Coverage for python/lsst/dax/apdb/tests/data_factory.py: 23%

49 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-12 11:13 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import random 

25from typing import Any, Iterator 

26 

27import numpy 

28import pandas 

29from lsst.daf.base import DateTime 

30from lsst.geom import SpherePoint 

31from lsst.sphgeom import LonLat, Region, UnitVector3d 

32 

33 

def _genPointsInRegion(region: Region, count: int) -> Iterator[SpherePoint]:
    """Yield ``count`` random points that lie inside ``region``.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of points to generate.

    Yields
    ------
    point : `lsst.geom.SpherePoint`
        A point contained in ``region``.

    Notes
    -----
    Candidates are drawn with `random.uniform` independently in longitude and
    latitude over the region's bounding box. Candidates the region does not
    contain are rejected. Because sampling is uniform in (lon, lat) rather
    than on the sphere, the returned points are random but not necessarily
    uniformly distributed. The module-level `random` generator is used, so
    results can be reproduced with `random.seed`.

    If the region contains none of the sampled candidates this loops
    forever; callers are presumably expected to pass non-empty regions.
    """
    box = region.getBoundingBox()
    mid = box.getCenter()
    mid_lon = mid.getLon().asRadians()
    mid_lat = mid.getLat().asRadians()
    half_width = box.getWidth().asRadians() / 2
    half_height = box.getHeight().asRadians() / 2

    # The sampling limits are loop-invariant, so compute them once.
    lon_lo, lon_hi = mid_lon - half_width, mid_lon + half_width
    lat_lo, lat_hi = mid_lat - half_height, mid_lat + half_height

    produced = 0
    while produced < count:
        # Longitude is drawn before latitude, keeping the same random stream.
        candidate = LonLat.fromRadians(
            random.uniform(lon_lo, lon_hi), random.uniform(lat_lo, lat_hi)
        )
        if not region.contains(UnitVector3d(candidate)):
            continue
        yield SpherePoint(candidate)
        produced += 1

62 

63 

def makeObjectCatalog(
    region: Region, count: int, visit_time: DateTime, *, start_id: int = 1, **kwargs: Any
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of records to generate.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit. It is converted with ``visit_time.toPython()`` and
        stored in the ``lastNonForcedSource`` column.
    start_id : `int`, optional
        Starting diaObjectId. IDs are consecutive, starting at this value.
    **kwargs : `Any`
        Additional columns and their values to add to the catalog. Scalars
        are broadcast to every row; sequences must have ``count`` elements.
        A key that collides with a generated column name is overridden by
        the generated values.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Besides any ``kwargs`` columns, the returned catalog contains these
    columns:

    - ``diaObjectId`` (int64)
    - ``ra`` and ``dec`` (float64, in degrees)
    - ``nDiaSources`` (int32, always 1)
    - ``lastNonForcedSource`` (the converted ``visit_time``)
    """
    points = list(_genPointsInRegion(region, count))
    nrows = len(points)
    # diaObjectId=0 may be used in some code for DiaSource foreign key to mean
    # the same as ``None``, hence the default ``start_id`` of 1.
    ids = numpy.arange(start_id, start_id + nrows, dtype=numpy.int64)
    ras = numpy.array([sp.getRa().asDegrees() for sp in points], dtype=numpy.float64)
    decs = numpy.array([sp.getDec().asDegrees() for sp in points], dtype=numpy.float64)
    # Generated columns are given after ``kwargs`` so they take precedence
    # over any caller-supplied column with the same name.
    data = dict(
        kwargs,
        diaObjectId=ids,
        ra=ras,
        dec=decs,
        nDiaSources=numpy.ones(nrows, dtype=numpy.int32),
        lastNonForcedSource=visit_time.toPython(),
    )
    return pandas.DataFrame(data)

110 

111 

def makeSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, start_id: int = 0, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    One DiaSource is generated for each row of ``objects``.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records. It must have ``diaObjectId``, ``ra``
        and ``dec`` columns.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit. Its MJD value, ``visit_time.get(system=DateTime.MJD)``,
        is stored in the ``midpointMjdTai`` column.
    start_id : `int`, optional
        Starting value for ``diaSourceId``. IDs are consecutive, starting at
        this value.
    ccdVisitId : `int`, optional
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records.

    Notes
    -----
    The returned catalog only contains the small number of columns needed
    for tests:

    - ``diaSourceId``
    - ``diaObjectId``, ``ra`` and ``dec`` (copied from ``objects``)
    - ``ccdVisitId``
    - ``parentDiaSourceId`` (always 0)
    - ``midpointMjdTai``
    - ``flags`` (always 0)

    Because some columns are taken as Series from ``objects``, the result
    shares the row index of ``objects``.
    """
    nrows = len(objects)
    midpointMjdTai = visit_time.get(system=DateTime.MJD)
    df = pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            "diaObjectId": objects["diaObjectId"],
            "ccdVisitId": numpy.full(nrows, ccdVisitId, dtype=numpy.int64),
            # Scalar is broadcast to every row by pandas.
            "parentDiaSourceId": 0,
            "ra": objects["ra"],
            "dec": objects["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            "flags": numpy.full(nrows, 0, dtype=numpy.int64),
        }
    )
    return df

153 

154 

def makeForcedSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Build one DiaForcedSource record per input DiaObject.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records. Only its ``diaObjectId`` column is
        read.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit. Its MJD value, ``visit_time.get(system=DateTime.MJD)``,
        fills the ``midpointMjdTai`` column.
    ccdVisitId : `int`, optional
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records with the columns ``diaObjectId``,
        ``ccdVisitId``, ``midpointMjdTai`` and ``flags`` (always 0).

    Notes
    -----
    Only the small number of columns needed for tests is produced.
    """
    n_objects = len(objects)
    mjd = visit_time.get(system=DateTime.MJD)

    object_ids = objects["diaObjectId"]
    visit_ids = numpy.full(n_objects, ccdVisitId, dtype=numpy.int64)
    midpoints = numpy.full(n_objects, mjd, dtype=numpy.float64)
    zero_flags = numpy.zeros(n_objects, dtype=numpy.int64)

    # Dict insertion order fixes the column order of the result.
    return pandas.DataFrame(
        {
            "diaObjectId": object_ids,
            "ccdVisitId": visit_ids,
            "midpointMjdTai": midpoints,
            "flags": zero_flags,
        }
    )

190 

191 

def makeSSObjectCatalog(count: int, start_id: int = 1, flags: int = 0) -> pandas.DataFrame:
    """Make a catalog containing a bunch of SSObjects.

    Parameters
    ----------
    count : `int`
        Number of records to generate.
    start_id : `int`, optional
        Initial SSObject ID. IDs are consecutive, starting at this value.
    flags : `int`, optional
        Value for ``flags`` column, identical for every row.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of SSObjects records.

    Notes
    -----
    The returned catalog only contains three columns:

    - ``ssObjectId`` (int64)
    - ``arc`` (float32, always 0.001)
    - ``flags`` (int64)
    """
    ids = numpy.arange(start_id, start_id + count, dtype=numpy.int64)
    arc = numpy.full(count, 0.001, dtype=numpy.float32)
    flags_array = numpy.full(count, flags, dtype=numpy.int64)
    return pandas.DataFrame({"ssObjectId": ids, "arc": arc, "flags": flags_array})
218 return df