Coverage for python/lsst/dax/apdb/tests/_apdb.py: 15%
305 statements
« prev ^ index » next coverage.py v7.3.3, created at 2023-12-20 17:15 +0000
« prev ^ index » next coverage.py v7.3.3, created at 2023-12-20 17:15 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest"]
26from abc import ABC, abstractmethod
27from typing import TYPE_CHECKING, Any
29import pandas
30from lsst.daf.base import DateTime
31from lsst.dax.apdb import ApdbConfig, ApdbInsertId, ApdbSql, ApdbTableData, ApdbTables, make_apdb
32from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
34from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
if not TYPE_CHECKING:

    class TestCaseMixin:
        """Empty stand-in used at run time as the base of mixin test classes."""

else:
    # Only static type checkers take this branch; it lets them see the
    # ``unittest.TestCase`` assert methods on classes deriving from the mixin.
    import unittest

    class TestCaseMixin(unittest.TestCase):
        """Type-checking-only base for mixin test classes that call TestCase
        methods.
        """
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build a circular region, centered on the direction ``xyz``, for tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components passed to `UnitVector3d` to define the circle center.

    Returns
    -------
    region : `Region`
        `Circle` whose opening angle is half of the 0.05 radian field of view.
    """
    field_of_view = 0.05  # radians
    return Circle(UnitVector3d(*xyz), Angle(field_of_view / 2))
class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.

    Subclasses must implement `make_config` and `getDiaObjects_table`, and
    may override the class attributes below to describe the capabilities of
    the implementation under test.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when containsVisitDetector is implemented; when false the
    tests expect that method to raise `NotImplementedError`.
    """

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate table data type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``. One additional column (for insert_id)
            is expected on top of that count.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        # rows() is only iterated here, so count the rows one by one.
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        apdb = make_apdb(config)

        apdb.makeSchema()
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            self.assertFalse(apdb.containsVisitDetector(visit=0, detector=0))
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            self.assertFalse(apdb.containsCcdVisit(1))

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # The result has to be captured here, otherwise the assertion
            # below would inspect the outcome of an earlier query.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            self.assertFalse(apdb.containsVisitDetector(visit=0, detector=0))
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            self.assertFalse(apdb.containsCcdVisit(1))

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            self.assertTrue(apdb.containsVisitDetector(visit=0, detector=0))
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            self.assertTrue(apdb.containsCcdVisit(1))
            self.assertFalse(apdb.containsCcdVisit(42))

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            self.assertTrue(apdb.containsCcdVisit(1))
            self.assertFalse(apdb.containsCcdVisit(42))

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # Eight store() calls, alternating between the two object catalogs.
        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            # one insert ID per store() call above
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                """Check record counts in all three history tables; by
                default ``nobj`` records are expected per insert ID.
                """
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # after reassigning three sources, three fewer are returned for
        # the original object IDs
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # NOTE(review): ``res`` is the catalog read before the failing call
        # and is not re-read here, so this repeats the previous check;
        # confirm whether a fresh getDiaSources query was intended.
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at a time later than both saves plus 360 days should
        # return nothing
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at a time later than both saves plus 360 days should
        # return nothing
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Mixin base class with tests exercising APDB across schema changes.

    Like other mixins in this module it relies on ``unittest.TestCase``
    assert methods being available on the final test class.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Build the configuration object used by the tests.

        Every call has to return a configuration that refers to the very
        same database instance (i.e. identical ``db_url``; for sqlite that
        implies on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Verify that an instance configured with insert IDs still works on
        a schema that was created without history tables.
        """
        # Create the schema with insert IDs disabled.
        apdb = make_apdb(self.make_config(use_insert_id=False))
        apdb.makeSchema()

        # Open the same database again, this time asking for insert IDs.
        apdb = make_apdb(self.make_config(use_insert_id=True))

        # Storing data is expected to succeed; objects are generated first
        # because the other catalogs are derived from them.
        sky_region = _make_region()
        when = self.visit_time
        dia_objects = makeObjectCatalog(sky_region, 100, when)
        dia_sources = makeSourceCatalog(dia_objects, when)
        dia_forced_sources = makeForcedSourceCatalog(dia_objects, when)
        apdb.store(when, dia_objects, dia_sources, dia_forced_sources)

        # No history is expected to be reported.
        self.assertIsNone(apdb.getInsertIds())