Coverage for tests/test_apdbSql.py : 15%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Unit test for Apdb class.
23"""
25import pandas
26import random
27from typing import Iterator
28import unittest
30from lsst.daf.base import DateTime
31from lsst.dax.apdb import ApdbSql, ApdbSqlConfig
32from lsst.sphgeom import Angle, Circle, LonLat, Region, UnitVector3d
33from lsst.geom import SpherePoint
34import lsst.utils.tests
def _makeRegion(fov: float = 0.05, pointing: tuple = (1., 1., -1.)) -> Region:
    """Make a circular sky region to use in tests.

    Parameters
    ----------
    fov : `float`, optional
        Diameter of the circle in radians; the circle's opening angle
        (radius) is ``fov / 2``.
    pointing : `tuple` [`float`], optional
        Cartesian (x, y, z) direction of the circle center, passed to
        `UnitVector3d`. The default is not unit length, so this relies on
        `UnitVector3d` normalizing its input.

    Returns
    -------
    region : `lsst.sphgeom.Region`
        A `Circle` centered on ``pointing`` with diameter ``fov``.
    """
    pointing_v = UnitVector3d(*pointing)
    return Circle(pointing_v, Angle(fov / 2))
def _makeVectors(region: Region, count: int = 1) -> Iterator[SpherePoint]:
    """Yield ``count`` random SpherePoints that lie inside ``region``.

    Candidates are drawn with `random.uniform` in longitude and latitude
    over the region's bounding box. Candidates outside the region itself
    are rejected. The result is random but not necessarily uniformly
    distributed on the sphere.
    """
    box = region.getBoundingBox()
    mid = box.getCenter()
    lon0 = mid.getLon().asRadians()
    lat0 = mid.getLat().asRadians()
    half_w = box.getWidth().asRadians() / 2
    half_h = box.getHeight().asRadians() / 2

    remaining = count
    while remaining > 0:
        # Longitude is drawn before latitude (left-to-right argument order).
        candidate = LonLat.fromRadians(
            random.uniform(lon0 - half_w, lon0 + half_w),
            random.uniform(lat0 - half_h, lat0 + half_h),
        )
        if not region.contains(UnitVector3d(candidate)):
            continue
        yield SpherePoint(candidate)
        remaining -= 1
def _makeObjectCatalogPandas(region, count: int, config: ApdbSqlConfig):
    """Make a catalog containing a bunch of DiaObjects inside region.

    Parameters
    ----------
    region : `lsst.sphgeom.Region`
        Region that every generated object falls inside.
    count : `int`
        Number of objects to generate.
    config : `ApdbSqlConfig`
        APDB configuration. Not used by this function; it is kept for
        interface compatibility with existing callers.

    Returns
    -------
    catalog : `pandas.DataFrame`
        One row per object, with columns ``diaObjectId`` (sequential,
        starting at 0), ``ra`` and ``decl`` (random positions in degrees).
    """
    data_list = [
        {"diaObjectId": oid,
         "ra": sp.getRa().asDegrees(),
         "decl": sp.getDec().asDegrees()}
        for oid, sp in enumerate(_makeVectors(region, count))
    ]
    return pandas.DataFrame(data=data_list)
def _makeSourceCatalogPandas(objects, visit_time, start_id=0):
    """Build a DiaSource catalog with one source per row of ``objects``.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        DiaObject catalog. The ``diaObjectId``, ``ra`` and ``decl`` columns
        are copied into the sources.
    visit_time : `lsst.daf.base.DateTime`
        Time whose MJD value is stored as ``midPointTai`` for every source.
    start_id : `int`, optional
        ``diaSourceId`` of the first source. Following sources get
        consecutive IDs.

    Returns
    -------
    catalog : `pandas.DataFrame`
        The generated DiaSources.
    """
    mid_point_tai = visit_time.get(system=DateTime.MJD)
    rows = [
        {"diaSourceId": source_id,
         "ccdVisitId": 1,
         "diaObjectId": obj["diaObjectId"],
         "parentDiaSourceId": 0,
         "ra": obj["ra"],
         "decl": obj["decl"],
         "midPointTai": mid_point_tai,
         "flags": 0}
        for source_id, (_, obj) in enumerate(objects.iterrows(),
                                             start=start_id)
    ]
    return pandas.DataFrame(data=rows)
def _makeForcedSourceCatalogPandas(objects, visit_time, ccdVisitId=1):
    """Make a catalog containing a bunch of DiaForcedSources associated with
    the input diaObjects.

    Parameters
    ----------
    objects : `pandas.DataFrame`
        DiaObject catalog. Each row produces one forced source carrying its
        ``diaObjectId``.
    visit_time : `lsst.daf.base.DateTime`
        Time whose MJD value is stored as ``midPointTai`` for every source.
    ccdVisitId : `int`, optional
        Value stored in the ``ccdVisitId`` column of every source.

    Returns
    -------
    catalog : `pandas.DataFrame`
        The generated DiaForcedSources.
    """
    midPointTai = visit_time.get(system=DateTime.MJD)
    catalog = [
        {"diaObjectId": obj["diaObjectId"],
         "ccdVisitId": ccdVisitId,
         "midPointTai": midPointTai,
         "flags": 0}
        for _, obj in objects.iterrows()
    ]
    return pandas.DataFrame(data=catalog)
class ApdbTestCase(unittest.TestCase):
    """A test case for ApdbSql class.
    """

    # Catalog type expected from ApdbSql get methods.
    data_type = pandas.DataFrame

    def _makeApdb(self, db_url="sqlite:///", **kwargs):
        """Make an ApdbSql instance and create its schema.

        Parameters
        ----------
        db_url : `str`, optional
            Database URL. The default ``sqlite:///`` is the URL used by most
            tests here.
        **kwargs
            Additional ``ApdbSqlConfig`` fields, passed as keywords.

        Returns
        -------
        config : `ApdbSqlConfig`
            Configuration used to make the instance.
        apdb : `ApdbSql`
            Instance on which ``makeSchema()`` has already been called.
        """
        config = ApdbSqlConfig(db_url=db_url, **kwargs)
        apdb = ApdbSql(config)
        apdb.makeSchema()
        return config, apdb

    def _assertCatalog(self, catalog, size, type=pandas.DataFrame):
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``type``.
        size : `int`
            Expected catalog size
        type : `type`, optional
            Expected catalog type
        """
        # ``type`` shadows the builtin, but callers pass it by keyword,
        # so the name is kept.
        self.assertIsInstance(catalog, type)
        self.assertEqual(len(catalog), size)

    def test_makeSchema(self):
        """Test for making an instance of Apdb using in-memory sqlite engine.
        """
        # sqlite does not support default READ_COMMITTED, for in-memory
        # database have to use connection pool
        # the essence of a test here is that there are no exceptions.
        self._makeApdb(db_url="sqlite://")

    def test_emptyGetsBaseline0months(self):
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """

        # set read_sources_months to 0 so that Forced/Sources are None
        _, apdb = self._makeApdb(read_sources_months=0,
                                 read_forced_sources_months=0)

        region = _makeRegion()
        visit_time = DateTime.now()

        # get objects by region
        res = apdb.getDiaObjects(region)
        self._assertCatalog(res, 0, type=self.data_type)

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

    def test_emptyGetsBaseline(self):
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """

        # use non-zero months for Forced/Source fetching
        _, apdb = self._makeApdb(read_sources_months=12,
                                 read_forced_sources_months=12)

        region = _makeRegion()
        visit_time = DateTime.now()

        # get objects by region
        res = apdb.getDiaObjects(region)
        self._assertCatalog(res, 0, type=self.data_type)

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

        # SQL implementation needs ID list
        with self.assertRaises(NotImplementedError):
            apdb.getDiaForcedSources(region, None, visit_time)

    def test_emptyGetsObjectLast(self):
        """Test for getting DiaObjects from empty database using DiaObjectLast
        table.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """

        # don't care about sources.
        _, apdb = self._makeApdb(dia_object_index="last_object_table")

        region = _makeRegion()

        # get objects by region
        res = apdb.getDiaObjects(region)
        self._assertCatalog(res, 0, type=self.data_type)

    def test_storeObjectsBaseline(self):
        """Store and retrieve DiaObjects."""

        # don't care about sources.
        config, apdb = self._makeApdb(dia_object_index="baseline")

        region = _makeRegion()
        visit_time = DateTime.now()

        # make catalog with Objects
        catalog = _makeObjectCatalogPandas(region, 100, config)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self._assertCatalog(res, len(catalog), type=self.data_type)

    def test_storeObjectsLast(self):
        """Store and retrieve DiaObjects using DiaObjectLast table."""
        # don't care about sources.
        config, apdb = self._makeApdb(dia_object_index="last_object_table",
                                      object_last_replace=True)

        region = _makeRegion()
        visit_time = DateTime.now()

        # make catalog with Objects
        catalog = _makeObjectCatalogPandas(region, 100, config)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self._assertCatalog(res, len(catalog), type=self.data_type)

    def test_storeSources(self):
        """Store and retrieve DiaSources."""
        config, apdb = self._makeApdb(read_sources_months=12,
                                      read_forced_sources_months=12)

        region = _makeRegion()
        visit_time = DateTime.now()

        # have to store Objects first
        objects = _makeObjectCatalogPandas(region, 100, config)
        oids = list(objects["diaObjectId"])
        sources = _makeSourceCatalogPandas(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self._assertCatalog(res, len(sources), type=self.data_type)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self._assertCatalog(res, len(sources), type=self.data_type)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

    def test_storeForcedSources(self):
        """Store and retrieve DiaForcedSources."""

        config, apdb = self._makeApdb(read_sources_months=12,
                                      read_forced_sources_months=12)

        region = _makeRegion()
        visit_time = DateTime.now()

        # have to store Objects first
        objects = _makeObjectCatalogPandas(region, 100, config)
        oids = list(objects["diaObjectId"])
        catalog = _makeForcedSourceCatalogPandas(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self._assertCatalog(res, len(catalog), type=self.data_type)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self._assertCatalog(res, 0, type=self.data_type)

    def test_midPointTai_src(self):
        """Test for time filtering of DiaSources.
        """
        config, apdb = self._makeApdb(read_sources_months=12,
                                      read_forced_sources_months=12)

        region = _makeRegion()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime(2021, 1, 1, 0, 0, 0, DateTime.TAI)
        src_time2 = DateTime(2021, 1, 1, 0, 0, 2, DateTime.TAI)
        visit_time0 = DateTime(2021, 12, 26, 23, 59, 59, DateTime.TAI)
        visit_time1 = DateTime(2021, 12, 27, 0, 0, 1, DateTime.TAI)
        visit_time2 = DateTime(2021, 12, 27, 0, 0, 3, DateTime.TAI)

        objects = _makeObjectCatalogPandas(region, 100, config)
        oids = list(objects["diaObjectId"])
        sources = _makeSourceCatalogPandas(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        # second batch continues diaSourceId numbering after the first 100
        sources = _makeSourceCatalogPandas(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self._assertCatalog(res, 200, type=self.data_type)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self._assertCatalog(res, 200, type=self.data_type)

        # one second past the window for src_time1: only the second batch
        res = apdb.getDiaSources(region, oids, visit_time1)
        self._assertCatalog(res, 100, type=self.data_type)

        # one second past the window for src_time2: nothing is returned
        res = apdb.getDiaSources(region, oids, visit_time2)
        self._assertCatalog(res, 0, type=self.data_type)

    def test_midPointTai_fsrc(self):
        """Test for time filtering of DiaForcedSources.
        """
        config, apdb = self._makeApdb(read_sources_months=12,
                                      read_forced_sources_months=12)

        region = _makeRegion()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime(2021, 1, 1, 0, 0, 0, DateTime.TAI)
        src_time2 = DateTime(2021, 1, 1, 0, 0, 2, DateTime.TAI)
        visit_time0 = DateTime(2021, 12, 26, 23, 59, 59, DateTime.TAI)
        visit_time1 = DateTime(2021, 12, 27, 0, 0, 1, DateTime.TAI)
        visit_time2 = DateTime(2021, 12, 27, 0, 0, 3, DateTime.TAI)

        objects = _makeObjectCatalogPandas(region, 100, config)
        oids = list(objects["diaObjectId"])
        # the two batches use different ccdVisitId values (1 and 2),
        # presumably so their rows do not collide for the same objects
        sources = _makeForcedSourceCatalogPandas(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = _makeForcedSourceCatalogPandas(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self._assertCatalog(res, 200, type=self.data_type)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self._assertCatalog(res, 200, type=self.data_type)

        # one second past the window for src_time1: only the second batch
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self._assertCatalog(res, 100, type=self.data_type)

        # one second past the window for src_time2: nothing is returned
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self._assertCatalog(res, 0, type=self.data_type)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the generic checks inherited from
    ``lsst.utils.tests.MemoryTestCase`` for this test module.
    """
def setup_module(module):
    """Initialize LSST test utilities before the tests in this module run.

    The ``module`` argument is supplied by the pytest module-setup hook and
    is not used here.
    """
    lsst.utils.tests.init()
# Allow running this file directly, outside of pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()