Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%
392 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-02 11:13 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-02 11:13 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]
26import contextlib
27import os
28import unittest
29from abc import ABC, abstractmethod
30from collections.abc import Iterator
31from tempfile import TemporaryDirectory
32from typing import TYPE_CHECKING, Any
34import pandas
35import yaml
36from lsst.daf.base import DateTime
37from lsst.dax.apdb import (
38 Apdb,
39 ApdbConfig,
40 ApdbInsertId,
41 ApdbSql,
42 ApdbTableData,
43 ApdbTables,
44 IncompatibleVersionError,
45 VersionTuple,
46 make_apdb,
47)
48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
# The mixin base is defined twice on purpose: static type checkers see a
# ``unittest.TestCase`` subclass (so ``self.assert*`` calls type-check), while
# at run time the mixin is a plain class and does not pull ``TestCase`` into
# the MRO on its own.
if not TYPE_CHECKING:

    class TestCaseMixin:
        """Empty stand-in for the mixin base class used during normal
        execution.
        """

else:

    class TestCaseMixin(unittest.TestCase):
        """Mixin base class, as seen by type checkers, for test classes that
        call TestCase methods.
        """
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build the circular region used by the tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the pointing direction; passed to `UnitVector3d`.

    Returns
    -------
    region : `Region`
        Circle centered on the pointing direction whose opening angle is
        half of the 0.05 radian field of view.
    """
    field_of_view = 0.05  # radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(field_of_view / 2))
@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config (nothing happens if the
        schema has no version), if `None` - don't change the version in
        config.

    Yields
    ------
    output_path : `str`
        Path for the updated schema file. The file is written into a
        temporary directory which is removed when the context exits, so the
        path is only usable inside the ``with`` block.
    """
    with open(schema_file) as yaml_stream:
        schemas_list = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))

    # Edit YAML contents, the file may contain more than one document.
    for schema in schemas_list:
        # Optionally drop metadata table.
        if drop_metadata:
            schema["tables"] = [table for table in schema["tables"] if table["name"] != "metadata"]
        if version is not None:
            if version:
                schema["version"] = version
            else:
                # Use pop() with a default so that removing the version from a
                # schema that never had one is not an error.
                schema.pop("version", None)

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as yaml_stream:
            yaml.dump_all(schemas_list, stream=yaml_stream)
        yield output_path
class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when containsVisitDetector is implemented"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Parameters
        ----------
        **kwargs
            Configuration overrides, tests pass options such as
            ``read_sources_months`` or ``schema_file``.
        """
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate table data type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # The result has to be assigned, otherwise the check below would
            # look at a stale value left over from one of the earlier queries.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # Eight visits alternating between the two regions.
        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        # Each visit stores ``nobj`` objects, sources and forced sources.
        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                # By default every insert ID is expected to contribute
                # ``nobj`` records to each of the three history tables.
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # Reassigned sources are no longer returned when filtering by the
        # original object IDs.
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # Query again so that the check actually verifies that the failed
        # call did not reassign anything; reusing the earlier result would
        # make this assertion trivially true.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at even later time should not return anything
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at even later time should not return anything
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        metadata = apdb.metadata

        # APDB should write two metadata items with version numbers and a
        # frozen JSON config.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), 3)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        # Existing key can only be overwritten with force=True.
        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        # delete() reports whether the key was actually present.
        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        config = self.make_config()
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(config.schema_file, drop_metadata=True) as schema_file:
            config_nometa = self.make_config(schema_file=schema_file)
            Apdb.makeSchema(config_nometa)
            apdb = make_apdb(config_nometa)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        apdb = make_apdb(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_config()
        default_schema = config.schema_file
        apdb = make_apdb(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Schema without a version is expected to be reported as 0.1.0.
        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_config()
        Apdb.makeSchema(config)

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.use_insert_id
        apdb = make_apdb(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.use_insert_id)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Mixin with unit tests checking behavior across schema changes.

    Concrete test cases have to provide `make_config`.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Every call has to return configuration for the very same database
        instance (i.e. identical ``db_url``), which for sqlite implies
        on-disk storage.
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Check that code configured for history tables still works with a
        database whose schema was made without them.
        """
        # Database is created with history (insert ID) support disabled.
        plain_config = self.make_config(use_insert_id=False)
        Apdb.makeSchema(plain_config)
        apdb = make_apdb(plain_config)

        # Same database, but now the client asks for history tables.
        history_config = self.make_config(use_insert_id=True)
        apdb = make_apdb(history_config)

        # Storing data should work OK; Objects are needed to make the
        # sources and forced sources.
        when = self.visit_time
        dia_objects = makeObjectCatalog(_make_region(), 100, when)
        dia_sources = makeSourceCatalog(dia_objects, when)
        dia_forced_sources = makeForcedSourceCatalog(dia_objects, when)
        apdb.store(when, dia_objects, dia_sources, dia_forced_sources)

        # No history is expected to be recorded.
        self.assertIsNone(apdb.getInsertIds())

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        base_config = self.make_config()
        Apdb.makeSchema(base_config)

        current_version = make_apdb(base_config).apdbSchemaVersion()
        self.assertEqual(current_version, VersionTuple(0, 1, 1))

        # Pretend that schema file is at version 99.0.0, instantiating APDB
        # against the existing database has to fail.
        with update_schema_yaml(base_config.schema_file, version="99.0.0") as bumped_schema:
            bumped_config = self.make_config(schema_file=bumped_schema)
            with self.assertRaises(IncompatibleVersionError):
                make_apdb(bumped_config)