Coverage for python/lsst/dax/apdb/tests/_apdb.py: 17%
276 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-12 09:46 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-12 09:46 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest"]
26from abc import ABC, abstractmethod
27from collections.abc import Callable
28from typing import TYPE_CHECKING, Any, ContextManager, Optional
30import pandas
31from lsst.daf.base import DateTime
32from lsst.dax.apdb import ApdbConfig, ApdbInsertId, ApdbTableData, ApdbTables, make_apdb
33from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
35from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build the circular sky region shared by the tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the direction vector of the circle center, passed
        to `UnitVector3d`.

    Returns
    -------
    region : `Region`
        Circle centered on ``xyz`` with an opening angle of half the
        0.05 radian field of view.
    """
    field_of_view = 0.05  # radians
    center = UnitVector3d(*xyz)
    opening_angle = Angle(field_of_view / 2)
    return Circle(center, opening_angle)
class ApdbTest(ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.

    Subclasses have to implement `make_config` and `getDiaObjects_table`
    and may override the class attributes below to describe the
    capabilities of the implementation under test.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Parameters
        ----------
        **kwargs
            Configuration overrides requested by an individual test, e.g.
            ``read_sources_months=0``.
        """
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate type and size of the table data returned by history
        queries.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        apdb = make_apdb(config)

        apdb.makeSchema()
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Keep the returned catalog, otherwise the assert below would
            # silently re-check the result of the previous query.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        # Every visit stores ``nobj`` objects, sources and forced sources;
        # ``start_id`` is advanced so that source IDs do not repeat.
        start_id = 0
        for store_time, objects in visits:
            sources = makeSourceCatalog(objects, store_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, store_time, ccdVisitId=start_id)
            apdb.store(store_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            # one insert ID per stored visit
            self.assertEqual(len(insert_ids), 8)

            def _check_history(ids: list[ApdbInsertId]) -> None:
                # Each insert ID contributes ``nobj`` records to each table.
                n_records = len(ids) * nobj
                res = apdb.getDiaObjectsHistory(ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            apdb.deleteInsertIds(insert_ids[:2])
            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # reassigned sources are no longer returned by a DiaObject ID query
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # Re-read after the failed call so that the check below actually
        # verifies that nothing was reassigned (source 7 included).
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second past 12 months for the first batch, only the second
        # batch should be returned
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # past 12 months for both batches, nothing should be returned
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second past 12 months for the first batch, only the second
        # batch should be returned
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # past 12 months for both batches, nothing should be returned
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertEqual: Callable[[Any, Any], None]
        assertIs: Callable[[Any, Any], None]
        assertIsInstance: Callable[[Any, Any], None]
        assertIsNone: Callable[[Any], None]
        assertIsNotNone: Callable[[Any], None]
        assertRaises: Callable[[Any], ContextManager]
        assertRaisesRegex: Callable[[Any, Any], ContextManager]
class ApdbSchemaUpdateTest(ABC):
    """Mixin with unit tests checking that APDB code copes with schema
    changes.

    Like `ApdbTest` this is meant to be combined with a
    unittest.TestCase, whose assert methods it calls.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Build the configuration object used by the tests.

        Every call has to produce a configuration that refers to the very
        same database instance (i.e. ``db_url`` must not change between
        calls, so for sqlite an on-disk file is required).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Verify that code configured for history tables still works with
        an older schema that was created without them.
        """
        # Create the schema with history tables disabled.
        apdb = make_apdb(self.make_config(use_insert_id=False))
        apdb.makeSchema()

        # Open the same database again, now configured for history tables.
        apdb = make_apdb(self.make_config(use_insert_id=True))

        # Storing data should succeed; Objects are needed before sources.
        sky_region = _make_region()
        when = self.visit_time
        dia_objects = makeObjectCatalog(sky_region, 100, when)
        dia_sources = makeSourceCatalog(dia_objects, when)
        dia_forced_sources = makeForcedSourceCatalog(dia_objects, when)
        apdb.store(when, dia_objects, dia_sources, dia_forced_sources)

        # No history is expected with the old schema.
        self.assertIsNone(apdb.getInsertIds())

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertIsNone: Callable[[Any], None]