Coverage for python/lsst/dax/apdb/tests/data_factory.py: 23%

49 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-27 03:01 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import random 

25from collections.abc import Iterator 

26from typing import Any 

27 

28import astropy.time 

29import numpy 

30import pandas 

31from lsst.sphgeom import LonLat, Region, UnitVector3d 

32 

33 

def _genPointsInRegion(region: Region, count: int) -> Iterator[LonLat]:
    """Generate random points inside given region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of points to generate.

    Yields
    ------
    lonlat : `lsst.sphgeom.LonLat`
        Random point contained in ``region``.

    Notes
    -----
    Candidate points are drawn uniformly in longitude and latitude within the
    region's bounding box, and candidates outside the region are rejected.
    Returned points are therefore random but not necessarily uniformly
    distributed on the sphere. The loop only stops after ``count`` accepted
    points, so a region covering little or none of its bounding box can make
    generation very slow or never finish.
    """
    bbox = region.getBoundingBox()
    center = bbox.getCenter()
    center_lon = center.getLon().asRadians()
    center_lat = center.getLat().asRadians()
    width = bbox.getWidth().asRadians()
    height = bbox.getHeight().asRadians()
    # Sampling limits do not change between iterations; compute them once.
    lon_min, lon_max = center_lon - width / 2, center_lon + width / 2
    lat_min, lat_max = center_lat - height / 2, center_lat + height / 2
    while count > 0:
        lon = random.uniform(lon_min, lon_max)
        lat = random.uniform(lat_min, lat_max)
        lonlat = LonLat.fromRadians(lon, lat)
        # Rejection sampling: the bounding box may cover more than the region.
        if region.contains(UnitVector3d(lonlat)):
            yield lonlat
            count -= 1

62 

63 

def makeObjectCatalog(
    region: Region, count: int, visit_time: astropy.time.Time, *, start_id: int = 1, **kwargs: Any
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of records to generate.
    visit_time : `astropy.time.Time`
        Time of the visit.
    start_id : `int`
        Starting diaObjectId.
    **kwargs : `Any`
        Additional columns and their values to add to catalog. If a keyword
        has the same name as one of the generated columns listed in Notes,
        the generated value is used and the keyword value is ignored.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Besides any ``kwargs`` columns, the returned catalog contains
    ``diaObjectId`` (consecutive, starting at ``start_id``), ``ra`` and
    ``dec`` (in degrees), ``nDiaSources`` (always 1), and
    ``lastNonForcedSource`` (``visit_time`` as a `datetime.datetime`).
    """
    points = list(_genPointsInRegion(region, count))
    # diaObjectId=0 may be used in some code for DiaSource foreign key to mean
    # the same as ``None``, hence the default start_id of 1.
    ids = numpy.arange(start_id, len(points) + start_id, dtype=numpy.int64)
    ras = numpy.array([lonlat.getLon().asDegrees() for lonlat in points], dtype=numpy.float64)
    decs = numpy.array([lonlat.getLat().asDegrees() for lonlat in points], dtype=numpy.float64)
    nDiaSources = numpy.ones(len(points), dtype=numpy.int32)
    dt = visit_time.datetime
    # Keyword arguments to dict() override entries from the positional
    # mapping, so generated columns win over same-named user kwargs.
    data = dict(
        kwargs,
        diaObjectId=ids,
        ra=ras,
        dec=decs,
        nDiaSources=nDiaSources,
        lastNonForcedSource=dt,
    )
    return pandas.DataFrame(data)

110 

111 

def makeSourceCatalog(
    objects: pandas.DataFrame, visit_time: astropy.time.Time, start_id: int = 0, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; must have ``diaObjectId``, ``ra`` and
        ``dec`` columns.
    visit_time : `astropy.time.Time`
        Time of the visit.
    start_id : `int`
        Starting value for ``diaSourceId``.
    ccdVisitId : `int`
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records, one per row of ``objects``.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests:
    ``diaSourceId``, ``diaObjectId``, ``ccdVisitId``, ``parentDiaSourceId``
    (always 0), ``ra``, ``dec`` (copied from ``objects``), ``midpointMjdTai``
    (MJD of ``visit_time``) and ``flags`` (always 0).
    """
    nrows = len(objects)
    midpointMjdTai = visit_time.mjd
    df = pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            "diaObjectId": objects["diaObjectId"],
            "ccdVisitId": numpy.full(nrows, ccdVisitId, dtype=numpy.int64),
            # Scalar is broadcast by pandas to every row.
            "parentDiaSourceId": 0,
            "ra": objects["ra"],
            "dec": objects["dec"],
            "midpointMjdTai": numpy.full(nrows, midpointMjdTai, dtype=numpy.float64),
            "flags": numpy.full(nrows, 0, dtype=numpy.int64),
        }
    )
    return df

153 

154 

def makeForcedSourceCatalog(
    objects: pandas.DataFrame, visit_time: astropy.time.Time, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Build DiaForcedSource records, one for each input DiaObject.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; its ``diaObjectId`` column is copied.
    visit_time : `astropy.time.Time`
        Time of the visit; its MJD fills ``midpointMjdTai``.
    ccdVisitId : `int`
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records with columns ``diaObjectId``,
        ``ccdVisitId``, ``midpointMjdTai`` and ``flags`` (always 0).

    Notes
    -----
    Only the handful of columns that tests need are produced.
    """
    n_objects = len(objects)
    visit_mjd = visit_time.mjd
    columns = {}
    columns["diaObjectId"] = objects["diaObjectId"]
    columns["ccdVisitId"] = numpy.full(n_objects, ccdVisitId, dtype=numpy.int64)
    columns["midpointMjdTai"] = numpy.full(n_objects, visit_mjd, dtype=numpy.float64)
    columns["flags"] = numpy.zeros(n_objects, dtype=numpy.int64)
    return pandas.DataFrame(columns)

190 

191 

def makeSSObjectCatalog(count: int, start_id: int = 1, flags: int = 0) -> pandas.DataFrame:
    """Make a catalog containing a bunch of SSObjects.

    Parameters
    ----------
    count : `int`
        Number of records to generate.
    start_id : `int`
        Initial SSObject ID; IDs are consecutive from this value.
    flags : `int`
        Value for ``flags`` column.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of SSObjects records.

    Notes
    -----
    Returned catalog only contains three columns - ``ssObjectId`` (int64),
    ``arc`` (float32, always 0.001), and ``flags`` (int64).
    """
    ids = numpy.arange(start_id, count + start_id, dtype=numpy.int64)
    arc = numpy.full(count, 0.001, dtype=numpy.float32)
    flags_array = numpy.full(count, flags, dtype=numpy.int64)
    return pandas.DataFrame({"ssObjectId": ids, "arc": arc, "flags": flags_array})