Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%
391 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-13 09:59 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-13 09:59 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]
26import contextlib
27import os
28import tempfile
29import unittest
30from abc import ABC, abstractmethod
31from collections.abc import Iterator
32from tempfile import TemporaryDirectory
33from typing import TYPE_CHECKING, Any
35import astropy.time
36import pandas
37import yaml
38from lsst.dax.apdb import (
39 Apdb,
40 ApdbConfig,
41 ApdbInsertId,
42 ApdbSql,
43 ApdbTableData,
44 ApdbTables,
45 IncompatibleVersionError,
46 VersionTuple,
47)
48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Static-typing view of the mixin base.

        Pretend to subclass `unittest.TestCase` so type checkers accept
        calls to its assert methods from the mixin classes below.
        """

else:

    class TestCaseMixin:
        """Do-nothing runtime stand-in.

        At execution time the real `unittest.TestCase` comes from the
        concrete test class that also inherits this mixin.
        """
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular region centered on ``xyz`` for use in tests."""
    center = UnitVector3d(*xyz)
    fov = 0.05  # field-of-view diameter in radians
    return Circle(center, Angle(fov / 2))
@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    with open(schema_file) as stream:
        documents = list(yaml.load_all(stream, Loader=yaml.SafeLoader))

    # Apply the requested edits to every YAML document in the file.
    for document in documents:
        if drop_metadata:
            # Filter out the metadata table definition.
            document["tables"] = [t for t in document["tables"] if t["name"] != "metadata"]
        if version is not None:
            if version:
                document["version"] = version
            else:
                del document["version"]

    # Write the edited schema into a temporary directory that stays alive
    # for the duration of the caller's with-block.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield updated_path
class ApdbTest(TestCaseMixin, ABC):
    """Mixin defining the common Apdb test suite.

    Concrete backend test classes specialize this class. It can only be
    used as a mixin together with a `unittest.TestCase` subclass because
    it calls various assert methods.
    """

    # Whether the implementation uses time-partitioned tables.
    time_partition_tables = False

    # Nominal visit time shared by most tests.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    # Should be set to True if getDiaForcedSources requires object IDs.
    fsrc_requires_id_list = False

    # Set to True when support for Insert IDs is configured.
    use_insert_id: bool = False

    # Set to True when containsVisitDetector is implemented.
    allow_visit_query: bool = True

    # Location of the Felis schema file.
    schema_path: str

    # Expected column counts, as defined in tests/config/schema.yaml.
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }
146 @abstractmethod
147 def make_instance(self, **kwargs: Any) -> ApdbConfig:
148 """Make database instance and return configuration for it."""
149 raise NotImplementedError()
151 @abstractmethod
152 def getDiaObjects_table(self) -> ApdbTables:
153 """Return type of table returned from getDiaObjects method."""
154 raise NotImplementedError()
156 def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
157 """Validate catalog type and size
159 Parameters
160 ----------
161 catalog : `object`
162 Expected type of this is ``pandas.DataFrame``.
163 rows : `int`
164 Expected number of rows in a catalog.
165 table : `ApdbTables`
166 APDB table type.
167 """
168 self.assertIsInstance(catalog, pandas.DataFrame)
169 self.assertEqual(catalog.shape[0], rows)
170 self.assertEqual(catalog.shape[1], self.table_column_count[table])
172 def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
173 """Validate catalog type and size
175 Parameters
176 ----------
177 catalog : `object`
178 Expected type of this is `ApdbTableData`.
179 rows : `int`
180 Expected number of rows in a catalog.
181 table : `ApdbTables`
182 APDB table type.
183 extra_columns : `int`
184 Count of additional columns expected in ``catalog``.
185 """
186 self.assertIsInstance(catalog, ApdbTableData)
187 n_rows = sum(1 for row in catalog.rows())
188 self.assertEqual(n_rows, rows)
189 # One extra column for insert_id
190 self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)
192 def test_makeSchema(self) -> None:
193 """Test for making APDB schema."""
194 config = self.make_instance()
195 apdb = Apdb.from_config(config)
197 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
198 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
199 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
200 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
201 self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))
203 # Test from_uri factory method with the same config.
204 with tempfile.NamedTemporaryFile() as tmpfile:
205 config.save(tmpfile.name)
206 apdb = Apdb.from_uri(tmpfile.name)
208 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
209 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
210 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
211 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
212 self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))
214 def test_empty_gets(self) -> None:
215 """Test for getting data from empty database.
217 All get() methods should return empty results, only useful for
218 checking that code is not broken.
219 """
220 # use non-zero months for Forced/Source fetching
221 config = self.make_instance()
222 apdb = Apdb.from_config(config)
224 region = _make_region()
225 visit_time = self.visit_time
227 res: pandas.DataFrame | None
229 # get objects by region
230 res = apdb.getDiaObjects(region)
231 self.assert_catalog(res, 0, self.getDiaObjects_table())
233 # get sources by region
234 res = apdb.getDiaSources(region, None, visit_time)
235 self.assert_catalog(res, 0, ApdbTables.DiaSource)
237 res = apdb.getDiaSources(region, [], visit_time)
238 self.assert_catalog(res, 0, ApdbTables.DiaSource)
240 # get sources by object ID, non-empty object list
241 res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
242 self.assert_catalog(res, 0, ApdbTables.DiaSource)
244 # get forced sources by object ID, empty object list
245 res = apdb.getDiaForcedSources(region, [], visit_time)
246 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
248 # get sources by object ID, non-empty object list
249 res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
250 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
252 # test if a visit has objects/sources
253 if self.allow_visit_query:
254 res = apdb.containsVisitDetector(visit=0, detector=0)
255 self.assertFalse(res)
256 else:
257 with self.assertRaises(NotImplementedError):
258 apdb.containsVisitDetector(visit=0, detector=0)
260 # alternative method not part of the Apdb API
261 if isinstance(apdb, ApdbSql):
262 res = apdb.containsCcdVisit(1)
263 self.assertFalse(res)
265 # get sources by region
266 if self.fsrc_requires_id_list:
267 with self.assertRaises(NotImplementedError):
268 apdb.getDiaForcedSources(region, None, visit_time)
269 else:
270 apdb.getDiaForcedSources(region, None, visit_time)
271 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
273 def test_empty_gets_0months(self) -> None:
274 """Test for getting data from empty database.
276 All get() methods should return empty DataFrame or None.
277 """
278 # set read_sources_months to 0 so that Forced/Sources are None
279 config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
280 apdb = Apdb.from_config(config)
282 region = _make_region()
283 visit_time = self.visit_time
285 res: pandas.DataFrame | None
287 # get objects by region
288 res = apdb.getDiaObjects(region)
289 self.assert_catalog(res, 0, self.getDiaObjects_table())
291 # get sources by region
292 res = apdb.getDiaSources(region, None, visit_time)
293 self.assertIs(res, None)
295 # get sources by object ID, empty object list
296 res = apdb.getDiaSources(region, [], visit_time)
297 self.assertIs(res, None)
299 # get forced sources by object ID, empty object list
300 res = apdb.getDiaForcedSources(region, [], visit_time)
301 self.assertIs(res, None)
303 # test if a visit has objects/sources
304 if self.allow_visit_query:
305 res = apdb.containsVisitDetector(visit=0, detector=0)
306 self.assertFalse(res)
307 else:
308 with self.assertRaises(NotImplementedError):
309 apdb.containsVisitDetector(visit=0, detector=0)
311 # alternative method not part of the Apdb API
312 if isinstance(apdb, ApdbSql):
313 res = apdb.containsCcdVisit(1)
314 self.assertFalse(res)
316 def test_storeObjects(self) -> None:
317 """Store and retrieve DiaObjects."""
318 # don't care about sources.
319 config = self.make_instance()
320 apdb = Apdb.from_config(config)
322 region = _make_region()
323 visit_time = self.visit_time
325 # make catalog with Objects
326 catalog = makeObjectCatalog(region, 100, visit_time)
328 # store catalog
329 apdb.store(visit_time, catalog)
331 # read it back and check sizes
332 res = apdb.getDiaObjects(region)
333 self.assert_catalog(res, len(catalog), self.getDiaObjects_table())
335 # TODO: test apdb.contains with generic implementation from DM-41671
337 def test_storeObjects_empty(self) -> None:
338 """Test calling storeObject when there are no objects: see DM-43270."""
339 config = self.make_instance()
340 apdb = Apdb.from_config(config)
341 region = _make_region()
342 visit_time = self.visit_time
343 # make catalog with no Objects
344 catalog = makeObjectCatalog(region, 0, visit_time)
346 with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
347 apdb.store(visit_time, catalog)
348 self.assertIn("No objects", "\n".join(cm.output))
350 def test_storeSources(self) -> None:
351 """Store and retrieve DiaSources."""
352 config = self.make_instance()
353 apdb = Apdb.from_config(config)
355 region = _make_region()
356 visit_time = self.visit_time
358 # have to store Objects first
359 objects = makeObjectCatalog(region, 100, visit_time)
360 oids = list(objects["diaObjectId"])
361 sources = makeSourceCatalog(objects, visit_time)
363 # save the objects and sources
364 apdb.store(visit_time, objects, sources)
366 # read it back, no ID filtering
367 res = apdb.getDiaSources(region, None, visit_time)
368 self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
370 # read it back and filter by ID
371 res = apdb.getDiaSources(region, oids, visit_time)
372 self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
374 # read it back to get schema
375 res = apdb.getDiaSources(region, [], visit_time)
376 self.assert_catalog(res, 0, ApdbTables.DiaSource)
378 # test if a visit is present
379 # data_factory's ccdVisitId generation corresponds to (0, 0)
380 if self.allow_visit_query:
381 res = apdb.containsVisitDetector(visit=0, detector=0)
382 self.assertTrue(res)
383 else:
384 with self.assertRaises(NotImplementedError):
385 apdb.containsVisitDetector(visit=0, detector=0)
387 # alternative method not part of the Apdb API
388 if isinstance(apdb, ApdbSql):
389 res = apdb.containsCcdVisit(1)
390 self.assertTrue(res)
391 res = apdb.containsCcdVisit(42)
392 self.assertFalse(res)
394 def test_storeForcedSources(self) -> None:
395 """Store and retrieve DiaForcedSources."""
396 config = self.make_instance()
397 apdb = Apdb.from_config(config)
399 region = _make_region()
400 visit_time = self.visit_time
402 # have to store Objects first
403 objects = makeObjectCatalog(region, 100, visit_time)
404 oids = list(objects["diaObjectId"])
405 catalog = makeForcedSourceCatalog(objects, visit_time)
407 apdb.store(visit_time, objects, forced_sources=catalog)
409 # read it back and check sizes
410 res = apdb.getDiaForcedSources(region, oids, visit_time)
411 self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)
413 # read it back to get schema
414 res = apdb.getDiaForcedSources(region, [], visit_time)
415 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
417 # TODO: test apdb.contains with generic implementation from DM-41671
419 # alternative method not part of the Apdb API
420 if isinstance(apdb, ApdbSql):
421 res = apdb.containsCcdVisit(1)
422 self.assertTrue(res)
423 res = apdb.containsCcdVisit(42)
424 self.assertFalse(res)
426 def test_getHistory(self) -> None:
427 """Store and retrieve catalog history."""
428 # don't care about sources.
429 config = self.make_instance()
430 apdb = Apdb.from_config(config)
431 visit_time = self.visit_time
433 region1 = _make_region((1.0, 1.0, -1.0))
434 region2 = _make_region((-1.0, -1.0, -1.0))
435 nobj = 100
436 objects1 = makeObjectCatalog(region1, nobj, visit_time)
437 objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)
439 visits = [
440 (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
441 (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
442 (astropy.time.Time("2021-01-01T00:03:00", format="isot", scale="tai"), objects1),
443 (astropy.time.Time("2021-01-01T00:04:00", format="isot", scale="tai"), objects2),
444 (astropy.time.Time("2021-01-01T00:05:00", format="isot", scale="tai"), objects1),
445 (astropy.time.Time("2021-01-01T00:06:00", format="isot", scale="tai"), objects2),
446 (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
447 (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
448 ]
450 start_id = 0
451 for visit_time, objects in visits:
452 sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
453 fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
454 apdb.store(visit_time, objects, sources, fsources)
455 start_id += nobj
457 insert_ids = apdb.getInsertIds()
458 if not self.use_insert_id:
459 self.assertIsNone(insert_ids)
461 with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
462 apdb.getDiaObjectsHistory([])
464 else:
465 assert insert_ids is not None
466 self.assertEqual(len(insert_ids), 8)
468 def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
469 if n_records is None:
470 n_records = len(insert_ids) * nobj
471 res = apdb.getDiaObjectsHistory(insert_ids)
472 self.assert_table_data(res, n_records, ApdbTables.DiaObject)
473 res = apdb.getDiaSourcesHistory(insert_ids)
474 self.assert_table_data(res, n_records, ApdbTables.DiaSource)
475 res = apdb.getDiaForcedSourcesHistory(insert_ids)
476 self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)
478 # read it back and check sizes
479 _check_history(insert_ids)
480 _check_history(insert_ids[1:])
481 _check_history(insert_ids[1:-1])
482 _check_history(insert_ids[3:4])
483 _check_history([])
485 # try to remove some of those
486 deleted_ids = insert_ids[:2]
487 apdb.deleteInsertIds(deleted_ids)
489 # All queries on deleted ids should return empty set.
490 _check_history(deleted_ids, 0)
492 insert_ids = apdb.getInsertIds()
493 assert insert_ids is not None
494 self.assertEqual(len(insert_ids), 6)
496 _check_history(insert_ids)
498 def test_storeSSObjects(self) -> None:
499 """Store and retrieve SSObjects."""
500 # don't care about sources.
501 config = self.make_instance()
502 apdb = Apdb.from_config(config)
504 # make catalog with SSObjects
505 catalog = makeSSObjectCatalog(100, flags=1)
507 # store catalog
508 apdb.storeSSObjects(catalog)
510 # read it back and check sizes
511 res = apdb.getSSObjects()
512 self.assert_catalog(res, len(catalog), ApdbTables.SSObject)
514 # check that override works, make catalog with SSObjects, ID = 51-150
515 catalog = makeSSObjectCatalog(100, 51, flags=2)
516 apdb.storeSSObjects(catalog)
517 res = apdb.getSSObjects()
518 self.assert_catalog(res, 150, ApdbTables.SSObject)
519 self.assertEqual(len(res[res["flags"] == 1]), 50)
520 self.assertEqual(len(res[res["flags"] == 2]), 100)
522 def test_reassignObjects(self) -> None:
523 """Reassign DiaObjects."""
524 # don't care about sources.
525 config = self.make_instance()
526 apdb = Apdb.from_config(config)
528 region = _make_region()
529 visit_time = self.visit_time
530 objects = makeObjectCatalog(region, 100, visit_time)
531 oids = list(objects["diaObjectId"])
532 sources = makeSourceCatalog(objects, visit_time)
533 apdb.store(visit_time, objects, sources)
535 catalog = makeSSObjectCatalog(100)
536 apdb.storeSSObjects(catalog)
538 # read it back and filter by ID
539 res = apdb.getDiaSources(region, oids, visit_time)
540 self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
542 apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
543 res = apdb.getDiaSources(region, oids, visit_time)
544 self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)
546 with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
547 apdb.reassignDiaSources(
548 {
549 1000: 1,
550 7: 3,
551 }
552 )
553 self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)
555 def test_midpointMjdTai_src(self) -> None:
556 """Test for time filtering of DiaSources."""
557 config = self.make_instance()
558 apdb = Apdb.from_config(config)
560 region = _make_region()
561 # 2021-01-01 plus 360 days is 2021-12-27
562 src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
563 src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
564 visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
565 visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
566 visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
568 objects = makeObjectCatalog(region, 100, visit_time0)
569 oids = list(objects["diaObjectId"])
570 sources = makeSourceCatalog(objects, src_time1, 0)
571 apdb.store(src_time1, objects, sources)
573 sources = makeSourceCatalog(objects, src_time2, 100)
574 apdb.store(src_time2, objects, sources)
576 # reading at time of last save should read all
577 res = apdb.getDiaSources(region, oids, src_time2)
578 self.assert_catalog(res, 200, ApdbTables.DiaSource)
580 # one second before 12 months
581 res = apdb.getDiaSources(region, oids, visit_time0)
582 self.assert_catalog(res, 200, ApdbTables.DiaSource)
584 # reading at later time of last save should only read a subset
585 res = apdb.getDiaSources(region, oids, visit_time1)
586 self.assert_catalog(res, 100, ApdbTables.DiaSource)
588 # reading at later time of last save should only read a subset
589 res = apdb.getDiaSources(region, oids, visit_time2)
590 self.assert_catalog(res, 0, ApdbTables.DiaSource)
592 def test_midpointMjdTai_fsrc(self) -> None:
593 """Test for time filtering of DiaForcedSources."""
594 config = self.make_instance()
595 apdb = Apdb.from_config(config)
597 region = _make_region()
598 src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
599 src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
600 visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
601 visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
602 visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
604 objects = makeObjectCatalog(region, 100, visit_time0)
605 oids = list(objects["diaObjectId"])
606 sources = makeForcedSourceCatalog(objects, src_time1, 1)
607 apdb.store(src_time1, objects, forced_sources=sources)
609 sources = makeForcedSourceCatalog(objects, src_time2, 2)
610 apdb.store(src_time2, objects, forced_sources=sources)
612 # reading at time of last save should read all
613 res = apdb.getDiaForcedSources(region, oids, src_time2)
614 self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)
616 # one second before 12 months
617 res = apdb.getDiaForcedSources(region, oids, visit_time0)
618 self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)
620 # reading at later time of last save should only read a subset
621 res = apdb.getDiaForcedSources(region, oids, visit_time1)
622 self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)
624 # reading at later time of last save should only read a subset
625 res = apdb.getDiaForcedSources(region, oids, visit_time2)
626 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
628 def test_metadata(self) -> None:
629 """Simple test for writing/reading metadata table"""
630 config = self.make_instance()
631 apdb = Apdb.from_config(config)
632 metadata = apdb.metadata
634 # APDB should write two metadata items with version numbers and a
635 # frozen JSON config.
636 self.assertFalse(metadata.empty())
637 self.assertEqual(len(list(metadata.items())), 3)
639 metadata.set("meta", "data")
640 metadata.set("data", "meta")
642 self.assertFalse(metadata.empty())
643 self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})
645 with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
646 metadata.set("meta", "data1")
648 metadata.set("meta", "data2", force=True)
649 self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})
651 self.assertTrue(metadata.delete("meta"))
652 self.assertIsNone(metadata.get("meta"))
653 self.assertFalse(metadata.delete("meta"))
655 self.assertEqual(metadata.get("data"), "meta")
656 self.assertEqual(metadata.get("meta", "meta"), "meta")
658 def test_nometadata(self) -> None:
659 """Test case for when metadata table is missing"""
660 # We expect that schema includes metadata table, drop it.
661 with update_schema_yaml(self.schema_path, drop_metadata=True) as schema_file:
662 config = self.make_instance(schema_file=schema_file)
663 apdb = Apdb.from_config(config)
664 metadata = apdb.metadata
666 self.assertTrue(metadata.empty())
667 self.assertEqual(list(metadata.items()), [])
668 with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
669 metadata.set("meta", "data")
671 self.assertTrue(metadata.empty())
672 self.assertIsNone(metadata.get("meta"))
674 # Also check what happens when configured schema has metadata, but
675 # database is missing it. Database was initialized inside above context
676 # without metadata table, here we use schema config which includes
677 # metadata table.
678 config.schema_file = self.schema_path
679 apdb = Apdb.from_config(config)
680 metadata = apdb.metadata
681 self.assertTrue(metadata.empty())
683 def test_schemaVersionFromYaml(self) -> None:
684 """Check version number handling for reading schema from YAML."""
685 config = self.make_instance()
686 default_schema = config.schema_file
687 apdb = Apdb.from_config(config)
688 self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))
690 with update_schema_yaml(default_schema, version="") as schema_file:
691 config = self.make_instance(schema_file=schema_file)
692 apdb = Apdb.from_config(config)
693 self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))
695 with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
696 config = self.make_instance(schema_file=schema_file)
697 apdb = Apdb.from_config(config)
698 self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))
700 def test_config_freeze(self) -> None:
701 """Test that some config fields are correctly frozen in database."""
702 config = self.make_instance()
704 # `use_insert_id` is the only parameter that is frozen in all
705 # implementations.
706 config.use_insert_id = not self.use_insert_id
707 apdb = Apdb.from_config(config)
708 frozen_config = apdb.config # type: ignore[attr-defined]
709 self.assertEqual(frozen_config.use_insert_id, self.use_insert_id)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    # Nominal visit time shared by the tests.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Each call must return configuration pointing at the identical
        database instance (i.e. ``db_url`` must be the same, which for
        sqlite means on-disk storage).
        """
        raise NotImplementedError()
727 def test_schema_add_history(self) -> None:
728 """Check that new code can work with old schema without history
729 tables.
730 """
731 # Make schema without history tables.
732 config = self.make_instance(use_insert_id=False)
733 apdb = Apdb.from_config(config)
735 # Make APDB instance configured for history tables.
736 config.use_insert_id = True
737 apdb = Apdb.from_config(config)
739 # Try to insert something, should work OK.
740 region = _make_region()
741 visit_time = self.visit_time
743 # have to store Objects first
744 objects = makeObjectCatalog(region, 100, visit_time)
745 sources = makeSourceCatalog(objects, visit_time)
746 fsources = makeForcedSourceCatalog(objects, visit_time)
747 apdb.store(visit_time, objects, sources, fsources)
749 # There should be no history.
750 insert_ids = apdb.getInsertIds()
751 self.assertIsNone(insert_ids)
753 def test_schemaVersionCheck(self) -> None:
754 """Check version number compatibility."""
755 config = self.make_instance()
756 apdb = Apdb.from_config(config)
758 self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))
760 # Claim that schema version is now 99.0.0, must raise an exception.
761 with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
762 config.schema_file = schema_file
763 with self.assertRaises(IncompatibleVersionError):
764 apdb = Apdb.from_config(config)