Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%
387 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-11 03:30 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-11 03:30 -0700
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]
26import contextlib
27import os
28import tempfile
29import unittest
30from abc import ABC, abstractmethod
31from collections.abc import Iterator
32from tempfile import TemporaryDirectory
33from typing import TYPE_CHECKING, Any
35import astropy.time
36import pandas
37import yaml
38from lsst.dax.apdb import (
39 Apdb,
40 ApdbConfig,
41 ApdbReplica,
42 ApdbTableData,
43 ApdbTables,
44 IncompatibleVersionError,
45 ReplicaChunk,
46 VersionTuple,
47)
48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d
50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog
# Trick to give mixin classes access to unittest.TestCase assert methods for
# static type checking without deriving from TestCase at runtime (which would
# make unittest try to collect and run the mixins themselves).
if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Base class for mixin test classes that use TestCase methods."""

else:

    class TestCaseMixin:
        """Do-nothing definition of mixin base class for regular execution."""
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular sky region for use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the (un-normalized) vector pointing at the region
        center.
    """
    fov = 0.05  # field-of-view diameter, radians
    return Circle(UnitVector3d(*xyz), Angle(fov / 2))
@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Write a modified copy of a schema file and return its location.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    path : `str`
        Path for the updated schema file, located in a temporary directory
        that is cleaned up when the context exits.
    """
    with open(schema_file) as input_stream:
        documents = list(yaml.safe_load_all(input_stream))

    # Apply the requested edits to every YAML document.
    for document in documents:
        if drop_metadata:
            document["tables"] = [tbl for tbl in document["tables"] if tbl["name"] != "metadata"]
        if version == "":
            document.pop("version")
        elif version is not None:
            document["version"] = version

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as output_stream:
            yaml.dump_all(documents, stream=output_stream)
        yield updated_path
class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    schema_path: str
    """Location of the Felis schema file."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 11,
        ApdbTables.DiaForcedSource: 5,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        # Note: a stale ``extra_columns`` parameter description was removed
        # from this docstring; the method takes no such argument.
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for replica chunk id.
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

        # Test from_uri factory method with the same config.
        with tempfile.NamedTemporaryFile() as tmpfile:
            config.save(tmpfile.name)
            apdb = Apdb.from_uri(tmpfile.name)

            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Bug fix: the result was previously discarded and the assert
            # below re-checked the stale ``res`` from an earlier query.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            # Database is empty, no images exist.
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        region = _make_region()
        visit_time = self.visit_time
        # make catalog with no Objects
        catalog = makeObjectCatalog(region, 0, visit_time)

        with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (1, 1)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertTrue(res)
            # non-existent image
            res = apdb.containsVisitDetector(visit=2, detector=42)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertTrue(res)
            # non-existent image
            res = apdb.containsVisitDetector(visit=2, detector=42)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

    def test_getChunks(self) -> None:
        """Store and retrieve replica chunks."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # With the default 10 minutes replica chunk window we should have 4
        # records.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
                apdb_replica.getDiaObjectsChunks([])

        else:
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 4)

            def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
                if n_records is None:
                    n_records = len(replica_chunks) * nobj
                res = apdb_replica.getDiaObjectsChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb_replica.getDiaSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb_replica.getDiaForcedSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_chunks(replica_chunks, 800)
            _check_chunks(replica_chunks[1:], 600)
            _check_chunks(replica_chunks[1:-1], 400)
            _check_chunks(replica_chunks[2:3], 200)
            _check_chunks([])

            # try to remove some of those
            deleted_chunks = replica_chunks[:1]
            apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

            # All queries on deleted ids should return empty set.
            _check_chunks(deleted_chunks, 0)

            replica_chunks = apdb_replica.getReplicaChunks()
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 3)

            _check_chunks(replica_chunks, 600)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # NOTE(review): this re-checks the DataFrame fetched before the failed
        # reassign; re-query apdb here if the intent is to verify the database
        # itself was left unchanged.
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata

        # APDB should write three or four metadata items with version numbers
        # and a frozen JSON config (comment fixed: previously said "two or
        # three", which disagreed with the expected counts below).
        self.assertFalse(metadata.empty())
        expected_rows = 4 if self.enable_replica else 3
        self.assertEqual(len(list(metadata.items())), expected_rows)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(self.schema_path, drop_metadata=True) as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        config.schema_file = self.schema_path
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_instance()
        default_schema = config.schema_file
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(
                apdb._schema.schemaVersion(), VersionTuple(0, 1, 0)  # type: ignore[attr-defined]
            )

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(
                apdb._schema.schemaVersion(), VersionTuple(99, 0, 0)  # type: ignore[attr-defined]
            )

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_instance()

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.enable_replica
        apdb = Apdb.from_config(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.enable_replica)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    # Nominal visit time shared by all tests in this class.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_replica(self) -> None:
        """Check that new code can work with old schema without replica
        tables.
        """
        # Make schema without replica tables.
        config = self.make_instance(use_insert_id=False)
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)

        # Make APDB instance configured for replication.
        # Note: this rebinds ``apdb``; the first instance above only served to
        # initialize the database without replica tables.
        config.use_insert_id = True
        apdb = Apdb.from_config(config)

        # Try to insert something, should work OK.
        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        sources = makeSourceCatalog(objects, visit_time)
        fsources = makeForcedSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources, fsources)

        # There should be no replica chunks.
        replica_chunks = apdb_replica.getReplicaChunks()
        self.assertIsNone(replica_chunks)

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)