Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%
402 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-16 02:07 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-16 02:07 -0700
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]
26import contextlib
27import os
28import unittest
29from abc import ABC, abstractmethod
30from collections.abc import Iterator
31from tempfile import TemporaryDirectory
32from typing import TYPE_CHECKING, Any
34import pandas
35import yaml
36from lsst.daf.base import DateTime
37from lsst.dax.apdb import (
38 Apdb,
39 ApdbConfig,
40 ApdbInsertId,
41 ApdbSql,
42 ApdbTableData,
43 ApdbTables,
44 IncompatibleVersionError,
45 VersionTuple,
46 make_apdb,
47)
48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
# At type-checking time the mixin derives from unittest.TestCase so that mypy
# knows about the assert* methods used throughout the mixins below; at run
# time it is an empty base so these mixin classes are not collected and run
# as test cases on their own.
if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Base class for mixin test classes that use TestCase methods."""

else:

    class TestCaseMixin:
        """Do-nothing definition of mixin base class for regular execution."""
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Construct a small circular sky region for use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the (un-normalized) vector pointing at the region
        center.

    Returns
    -------
    region : `Region`
        Circular region with a 0.05 radian field of view.
    """
    field_of_view = 0.05  # radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(field_of_view / 2))
@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    # Read all YAML documents from the schema file up front.
    with open(schema_file) as stream:
        documents = list(yaml.load_all(stream, Loader=yaml.SafeLoader))

    # Apply the requested edits to each document.
    for document in documents:
        if drop_metadata:
            # Keep every table except "metadata".
            document["tables"] = [tbl for tbl in document["tables"] if tbl["name"] != "metadata"]
        if version == "":
            # Empty string means "remove the version entirely".
            del document["version"]
        elif version is not None:
            document["version"] = version

    # Write the edited documents to a temporary location; the file lives
    # only as long as this context manager.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield updated_path
class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        # Note: removed documentation for a nonexistent ``extra_columns``
        # parameter; the expected extra column count is fixed at one below.
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # get sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Bug fix: capture the returned catalog. Previously the result
            # was discarded and the assertion below re-checked the stale
            # ``res`` from an earlier call, so this query was never verified.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        region = _make_region()
        visit_time = self.visit_time
        # make catalog with no Objects
        catalog = makeObjectCatalog(region, 0, visit_time)

        with self.assertLogs("lsst.dax.apdb.apdbSql", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                # Each insert above wrote ``nobj`` records into every table.
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        metadata = apdb.metadata

        # APDB should write three metadata items: two version numbers and a
        # frozen JSON config.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), 3)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        config = self.make_config()
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(config.schema_file, drop_metadata=True) as schema_file:
            config_nometa = self.make_config(schema_file=schema_file)
            Apdb.makeSchema(config_nometa)
            apdb = make_apdb(config_nometa)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        apdb = make_apdb(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_config()
        default_schema = config.schema_file
        apdb = make_apdb(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_config()
        Apdb.makeSchema(config)

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.use_insert_id
        apdb = make_apdb(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.use_insert_id)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Check that new code can work with old schema without history
        tables.
        """
        # Initialize the database with a schema that lacks history tables.
        old_config = self.make_config(use_insert_id=False)
        Apdb.makeSchema(old_config)
        apdb = make_apdb(old_config)

        # Re-open the same database with history support requested.
        new_config = self.make_config(use_insert_id=True)
        apdb = make_apdb(new_config)

        # Inserting data should still succeed against the old schema.
        sky_region = _make_region()
        store_time = self.visit_time

        # Objects have to be stored together with sources/forced sources.
        objects = makeObjectCatalog(sky_region, 100, store_time)
        sources = makeSourceCatalog(objects, store_time)
        forced = makeForcedSourceCatalog(objects, store_time)
        apdb.store(store_time, objects, sources, forced)

        # No history should have been recorded.
        self.assertIsNone(apdb.getInsertIds())

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            incompatible_config = self.make_config(schema_file=schema_file)
            with self.assertRaises(IncompatibleVersionError):
                apdb = make_apdb(incompatible_config)