Coverage for python/lsst/dax/apdb/tests/data_factory.py: 23%
48 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-25 01:34 -0700
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-25 01:34 -0700
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24import numpy
25import pandas
26import random
27from typing import Iterator
29from lsst.daf.base import DateTime
30from lsst.sphgeom import LonLat, Region, UnitVector3d
31from lsst.geom import SpherePoint
def _genPointsInRegion(region: Region, count: int) -> Iterator[SpherePoint]:
    """Yield random points that lie inside a spherical region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of points to generate.

    Yields
    ------
    point : `lsst.geom.SpherePoint`
        A point accepted by ``region.contains``.

    Notes
    -----
    Candidates are drawn uniformly in longitude and latitude within the
    region's bounding box and kept only when ``region.contains`` accepts
    them. The resulting points are random but not necessarily uniformly
    distributed on the sphere. If no candidate is ever accepted (e.g. for
    an empty region), the generator never finishes.
    """
    box = region.getBoundingBox()
    middle = box.getCenter()
    mid_lon = middle.getLon().asRadians()
    mid_lat = middle.getLat().asRadians()
    half_w = box.getWidth().asRadians() / 2
    half_h = box.getHeight().asRadians() / 2
    produced = 0
    while produced < count:
        # Longitude is sampled before latitude so the random stream is
        # consumed in a fixed order.
        candidate = LonLat.fromRadians(
            random.uniform(mid_lon - half_w, mid_lon + half_w),
            random.uniform(mid_lat - half_h, mid_lat + half_h),
        )
        if not region.contains(UnitVector3d(candidate)):
            continue
        yield SpherePoint(candidate)
        produced += 1
def makeObjectCatalog(
    region: Region, count: int, visit_time: DateTime, *, start_id: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaObjects inside a region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Spherical region.
    count : `int`
        Number of records to generate.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit.
    start_id : `int`, optional
        Starting ``diaObjectId``; consecutive IDs are assigned from it.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaObjects records.

    Notes
    -----
    Returned catalog only contains these columns:

    - ``diaObjectId`` (int64, consecutive from ``start_id``),
    - ``ra`` and ``decl`` (float64, in degrees),
    - ``nDiaSources`` (int32, always 1),
    - ``lastNonForcedSource`` (``visit_time.toPython()`` for every row).
    """
    points = list(_genPointsInRegion(region, count))
    n_points = len(points)
    # diaObjectId=0 may be used in some code for DiaSource foreign key to mean
    # the same as ``None``, hence ``start_id`` defaults to 1.
    ids = numpy.arange(start_id, start_id + n_points, dtype=numpy.int64)
    ras = numpy.array([sp.getRa().asDegrees() for sp in points], dtype=numpy.float64)
    decls = numpy.array([sp.getDec().asDegrees() for sp in points], dtype=numpy.float64)
    nDiaSources = numpy.ones(n_points, dtype=numpy.int32)
    dt = visit_time.toPython()
    return pandas.DataFrame(
        {
            "diaObjectId": ids,
            "ra": ras,
            "decl": decls,
            "nDiaSources": nDiaSources,
            "lastNonForcedSource": dt,
        }
    )
def makeSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, start_id: int = 0, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Make a catalog containing a bunch of DiaSources associated with the
    input DiaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; must have ``diaObjectId``, ``ra`` and
        ``decl`` columns.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit.
    start_id : `int`, optional
        Starting value for ``diaSourceId``; consecutive IDs are assigned from
        it, one DiaSource per input DiaObject.
    ccdVisitId : `int`, optional
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaSource records.

    Notes
    -----
    Returned catalog only contains small number of columns needed for tests:
    ``diaSourceId``, ``diaObjectId``, ``ccdVisitId``, ``parentDiaSourceId``
    (always 0), ``ra``, ``decl``, ``midPointTai`` (``visit_time`` as MJD) and
    ``flags`` (always 0).
    """
    nrows = len(objects)
    midPointTai = visit_time.get(system=DateTime.MJD)
    return pandas.DataFrame(
        {
            "diaSourceId": numpy.arange(start_id, start_id + nrows, dtype=numpy.int64),
            # Positions and object IDs are copied from the matching DiaObject.
            "diaObjectId": objects["diaObjectId"],
            "ccdVisitId": numpy.full(nrows, ccdVisitId, dtype=numpy.int64),
            "parentDiaSourceId": numpy.zeros(nrows, dtype=numpy.int64),
            "ra": objects["ra"],
            "decl": objects["decl"],
            "midPointTai": numpy.full(nrows, midPointTai, dtype=numpy.float64),
            "flags": numpy.zeros(nrows, dtype=numpy.int64),
        }
    )
def makeForcedSourceCatalog(
    objects: pandas.DataFrame, visit_time: DateTime, ccdVisitId: int = 1
) -> pandas.DataFrame:
    """Build a DiaForcedSource catalog with one record per input DiaObject.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        Catalog of DiaObject records; its ``diaObjectId`` column is copied.
    visit_time : `lsst.daf.base.DateTime`
        Time of the visit, stored as MJD in ``midPointTai``.
    ccdVisitId : `int`, optional
        Value for ``ccdVisitId`` field.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of DiaForcedSource records.

    Notes
    -----
    Only the columns needed by tests are produced: ``diaObjectId``,
    ``ccdVisitId``, ``midPointTai`` and ``flags`` (always 0).
    """
    size = len(objects)
    mjd = visit_time.get(system=DateTime.MJD)
    columns = {}
    columns["diaObjectId"] = objects["diaObjectId"]
    columns["ccdVisitId"] = numpy.full(size, ccdVisitId, dtype=numpy.int64)
    columns["midPointTai"] = numpy.full(size, mjd, dtype=numpy.float64)
    columns["flags"] = numpy.zeros(size, dtype=numpy.int64)
    return pandas.DataFrame(columns)
def makeSSObjectCatalog(count: int, start_id: int = 1, flags: int = 0) -> pandas.DataFrame:
    """Make a catalog containing a bunch of SSObjects.

    Parameters
    ----------
    count : `int`
        Number of records to generate.
    start_id : `int`, optional
        Initial SSObject ID; consecutive IDs are assigned from it.
    flags : `int`, optional
        Value for ``flags`` column, identical for every record.

    Returns
    -------
    catalog : `pandas.DataFrame`
        Catalog of SSObjects records.

    Notes
    -----
    Returned catalog only contains three columns - ``ssObjectId`` (int64),
    ``arc`` (float32, always 0.001), and ``flags`` (int64).
    """
    ids = numpy.arange(start_id, start_id + count, dtype=numpy.int64)
    arc = numpy.full(count, 0.001, dtype=numpy.float32)
    flags_array = numpy.full(count, flags, dtype=numpy.int64)
    return pandas.DataFrame({"ssObjectId": ids, "arc": arc, "flags": flags_array})