Coverage for python/lsst/dax/apdb/tests/_apdb.py: 17%
276 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-16 10:26 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-16 10:26 +0000
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest"]
26from abc import ABC, abstractmethod
27from collections.abc import Callable
28from typing import TYPE_CHECKING, Any, ContextManager, Optional
30import pandas
31from lsst.daf.base import DateTime
32from lsst.dax.apdb import ApdbConfig, ApdbInsertId, ApdbTableData, ApdbTables, make_apdb
33from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
35from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build the circular sky region used by the tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`], optional
        Components of the vector pointing at the circle center, passed
        unpacked to `UnitVector3d`.

    Returns
    -------
    region : `Region`
        `Circle` around ``xyz`` whose opening angle is half of the field
        of view (0.05 radians).
    """
    fov_radians = 0.05
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(fov_radians / 2))
class ApdbTest(ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected column count in
            ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate table data type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected column count in
            ``table_column_count``.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        apdb = make_apdb(config)

        apdb.makeSchema()
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Keep the returned catalog so that the check below validates
            # this call and not the result of the previous query.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # Eight visits alternating between the two object catalogs; each
        # visit is stored with a separate store() call below.
        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            # one insert ID is expected per store() call above
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId]) -> None:
                """Check that each history table returns ``nobj`` records
                per insert ID.
                """
                n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            apdb.deleteInsertIds(insert_ids[:2])
            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        # records that were not overridden keep flags=1, all records from
        # the second catalog have flags=2
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # three reassigned sources are expected to drop out of the
        # selection by object ID
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # NOTE(review): ``res`` was read before the failing call, so this
        # repeats the previous check. Re-reading the sources here would
        # verify that the failed call changed nothing - confirm that all
        # implementations guarantee that before changing it.
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading even later, both saved sets are expected to be filtered out
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading even later, both saved sets are expected to be filtered out
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertEqual: Callable[[Any, Any], None]
        assertIs: Callable[[Any, Any], None]
        assertIsInstance: Callable[[Any, Any], None]
        assertIsNone: Callable[[Any], None]
        assertIsNotNone: Callable[[Any], None]
        assertRaises: Callable[[Any], ContextManager]
        assertRaisesRegex: Callable[[Any, Any], ContextManager]
class ApdbSchemaUpdateTest(ABC):
    """Base class for unit tests that verify how schema changes work.

    This is a mixin class, it calls ``assertIsNone`` which is expected to be
    provided by unittest.TestCase.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Check that new code can work with old schema without history
        tables.
        """
        # Make schema without history tables.
        config = self.make_config(use_insert_id=False)
        apdb = make_apdb(config)
        apdb.makeSchema()

        # Make APDB instance configured for history tables.
        config = self.make_config(use_insert_id=True)
        apdb = make_apdb(config)

        # Try to insert something, should work OK.
        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        sources = makeSourceCatalog(objects, visit_time)
        fsources = makeForcedSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources, fsources)

        # There should be no history.
        insert_ids = apdb.getInsertIds()
        self.assertIsNone(insert_ids)

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertIsNone: Callable[[Any], None]