Coverage for python / lsst / dax / apdb / tests / _apdb.py: 12%
618 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-25 08:20 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-25 08:20 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]
26import contextlib
27import logging.config
28import os
29import tempfile
30from abc import ABC, abstractmethod
31from collections.abc import Iterator
32from tempfile import TemporaryDirectory
33from typing import TYPE_CHECKING, Any
35import astropy.time
36import felis.datamodel
37import pandas
38import yaml
40from lsst.sphgeom import Angle, Circle, LonLat, Region, UnitVector3d
42from .. import (
43 Apdb,
44 ApdbConfig,
45 ApdbReassignDiaSourceToSSObjectRecord,
46 ApdbReplica,
47 ApdbTableData,
48 ApdbTables,
49 ApdbUpdateRecord,
50 ApdbWithdrawDiaSourceRecord,
51 DiaObjectId,
52 DiaSourceId,
53 IncompatibleVersionError,
54 ReplicaChunk,
55 VersionTuple,
56)
57from .data_factory import (
58 makeForcedSourceCatalog,
59 makeObjectCatalog,
60 makeSourceCatalog,
61 makeTimestamp,
62 makeTimestampColumn,
63)
64from .utils import TestCaseMixin
66if TYPE_CHECKING:
67 from ..pixelization import Pixelization
# Optionally configure logging from a config file whose path is supplied in
# the DAX_APDB_TEST_LOG_CONFIG environment variable (no-op when unset/empty).
if log_config := os.environ.get("DAX_APDB_TEST_LOG_CONFIG"):
    logging.config.fileConfig(log_config)
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Make a region to use in tests"""
    fov = 0.0013  # field of view diameter, radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(fov / 2))
@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    with open(schema_file) as stream:
        documents = list(yaml.load_all(stream, Loader=yaml.SafeLoader))

    # Apply the requested edits to every YAML document in the file.
    for document in documents:
        if drop_metadata:
            # Drop the metadata table definition, keep everything else.
            document["tables"] = [t for t in document["tables"] if t["name"] != "metadata"]
        if version == "":
            del document["version"]
        elif version is not None:
            document["version"] = version

    # Dump the edited schema into a temporary file that lives for the
    # duration of the context.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        new_path = os.path.join(tmpdir, "schema.yaml")
        with open(new_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield new_path
class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    # Visit time used by most tests below.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    # Simulated "current" time, 12 hours after ``visit_time``; tests patch
    # Apdb._current_time to return this value.
    processing_time = astropy.time.Time("2021-01-01T12:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    use_mjd: bool = True
    """If True then timestamp columns are MJD TAI."""

    extra_chunk_columns = 1
    """Number of additional columns in chunk tables."""

    meta_row_count = 3
    """Initial row count in metadata table."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 6,
        ApdbTables.DiaSource: 12,
        ApdbTables.DiaForcedSource: 8,
        ApdbTables.SSObject: 3,
    }
    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it.

        Parameters
        ----------
        **kwargs : `Any`
            Implementation-specific configuration overrides.
        """
        raise NotImplementedError()
    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method.

        Different backends may return either DiaObject or DiaObjectLast.
        """
        raise NotImplementedError()
    @abstractmethod
    def pixelization(self, config: ApdbConfig) -> Pixelization:
        """Return pixelization used by implementation.

        Parameters
        ----------
        config : `ApdbConfig`
            Configuration returned by `make_instance`.
        """
        raise NotImplementedError()
177 def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
178 """Validate catalog type and size
180 Parameters
181 ----------
182 catalog : `object`
183 Expected type of this is ``pandas.DataFrame``.
184 rows : `int`
185 Expected number of rows in a catalog.
186 table : `ApdbTables`
187 APDB table type.
188 """
189 self.assertIsInstance(catalog, pandas.DataFrame)
190 self.assertEqual(catalog.shape[0], rows)
191 self.assertEqual(catalog.shape[1], self.table_column_count[table])
    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        # rows() may be a lazy iterable, so count by iterating.
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # Chunk tables carry extra columns (e.g. replica chunk id) on top of
        # the regular table schema.
        self.assertEqual(
            len(catalog.column_names()), self.table_column_count[table] + self.extra_chunk_columns
        )
215 def assert_column_types(self, catalog: Any, types: dict[str, felis.datamodel.DataType]) -> None:
216 column_defs = dict(catalog.column_defs())
217 for column, datatype in types.items():
218 self.assertEqual(column_defs[column], datatype)
    def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
        """Make a region to use in tests"""
        # Delegate to the module-level helper so all tests share the same
        # region geometry.
        return _make_region(xyz)
224 def test_makeSchema(self) -> None:
225 """Test for making APDB schema."""
226 config = self.make_instance()
227 apdb = Apdb.from_config(config)
229 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
230 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
231 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
232 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
233 self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))
234 self.assertIsNotNone(apdb.tableDef(ApdbTables.SSObject))
235 self.assertIsNotNone(apdb.tableDef(ApdbTables.SSSource))
236 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject_To_Object_Match))
238 # Test from_uri factory method with the same config.
239 with tempfile.NamedTemporaryFile() as tmpfile:
240 config.save(tmpfile.name)
241 apdb = Apdb.from_uri(tmpfile.name)
243 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
244 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
245 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
246 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
247 self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))
248 self.assertIsNotNone(apdb.tableDef(ApdbTables.SSObject))
249 self.assertIsNotNone(apdb.tableDef(ApdbTables.SSSource))
250 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject_To_Object_Match))
252 def test_empty_gets(self) -> None:
253 """Test for getting data from empty database.
255 All get() methods should return empty results, only useful for
256 checking that code is not broken.
257 """
258 # use non-zero months for Forced/Source fetching
259 config = self.make_instance()
260 apdb = Apdb.from_config(config)
262 region = self.make_region()
263 visit_time = self.visit_time
265 res: pandas.DataFrame | None
267 # get objects by region
268 res = apdb.getDiaObjects(region)
269 self.assert_catalog(res, 0, self.getDiaObjects_table())
271 # get sources by region
272 res = apdb.getDiaSources(region, None, visit_time)
273 self.assert_catalog(res, 0, ApdbTables.DiaSource)
275 res = apdb.getDiaSources(region, [], visit_time)
276 self.assert_catalog(res, 0, ApdbTables.DiaSource)
278 # get sources by object ID, non-empty object list
279 res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
280 self.assert_catalog(res, 0, ApdbTables.DiaSource)
282 # get forced sources by object ID, empty object list
283 res = apdb.getDiaForcedSources(region, [], visit_time)
284 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
286 # get sources by object ID, non-empty object list
287 res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
288 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
290 # data_factory's ccdVisitId generation corresponds to (1, 1)
291 res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
292 self.assertFalse(res)
294 # get sources by region
295 if self.fsrc_requires_id_list:
296 with self.assertRaises(NotImplementedError):
297 apdb.getDiaForcedSources(region, None, visit_time)
298 else:
299 res = apdb.getDiaForcedSources(region, None, visit_time)
300 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
302 def test_empty_gets_0months(self) -> None:
303 """Test for getting data from empty database.
305 All get() methods should return empty DataFrame or None.
306 """
307 # set read_sources_months to 0 so that Forced/Sources are None
308 config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
309 apdb = Apdb.from_config(config)
311 region = self.make_region()
312 visit_time = self.visit_time
314 res: pandas.DataFrame | None
316 # get objects by region
317 res = apdb.getDiaObjects(region)
318 self.assert_catalog(res, 0, self.getDiaObjects_table())
320 # get sources by region
321 res = apdb.getDiaSources(region, None, visit_time)
322 self.assertIs(res, None)
324 # get sources by object ID, empty object list
325 res = apdb.getDiaSources(region, [], visit_time)
326 self.assertIs(res, None)
328 # get forced sources by object ID, empty object list
329 res = apdb.getDiaForcedSources(region, [], visit_time)
330 self.assertIs(res, None)
332 # Database is empty, no images exist.
333 res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
334 self.assertFalse(res)
336 def test_storeObjects(self) -> None:
337 """Store and retrieve DiaObjects."""
338 # don't care about sources.
339 config = self.make_instance()
340 apdb = Apdb.from_config(config)
342 region = self.make_region()
343 visit_time = self.visit_time
345 # make catalog with Objects
346 catalog = makeObjectCatalog(region, 100)
348 # store catalog
349 apdb.store(visit_time, catalog)
351 # read it back and check sizes
352 res = apdb.getDiaObjects(region)
353 self.assert_catalog(res, len(catalog), self.getDiaObjects_table())
355 # TODO: test apdb.contains with generic implementation from DM-41671
357 def test_storeObjects_empty(self) -> None:
358 """Test calling storeObject when there are no objects: see DM-43270."""
359 config = self.make_instance()
360 apdb = Apdb.from_config(config)
361 region = self.make_region()
362 visit_time = self.visit_time
363 # make catalog with no Objects
364 catalog = makeObjectCatalog(region, 0)
366 with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
367 apdb.store(visit_time, catalog)
368 self.assertIn("No objects", "\n".join(cm.output))
370 def test_storeMovingObject(self) -> None:
371 """Store and retrieve DiaObject which changes its position."""
372 # don't care about sources.
373 config = self.make_instance()
374 apdb = Apdb.from_config(config)
375 pixelization = self.pixelization(config)
377 lon_deg, lat_deg = 0.0, 0.0
378 lonlat1 = LonLat.fromDegrees(lon_deg - 1.0, lat_deg)
379 lonlat2 = LonLat.fromDegrees(lon_deg + 1.0, lat_deg)
380 uv1 = UnitVector3d(lonlat1)
381 uv2 = UnitVector3d(lonlat2)
383 # Check that they fall into different pixels.
384 self.assertNotEqual(pixelization.pixel(uv1), pixelization.pixel(uv2))
386 # Store one object at two different positions.
387 visit_time1 = self.visit_time
388 catalog1 = makeObjectCatalog(lonlat1, 1)
389 apdb.store(visit_time1, catalog1)
391 visit_time2 = visit_time1 + astropy.time.TimeDelta(120.0, format="sec")
392 catalog1 = makeObjectCatalog(lonlat2, 1)
393 apdb.store(visit_time2, catalog1)
395 # Make region covering both points.
396 region = Circle(UnitVector3d(LonLat.fromDegrees(lon_deg, lat_deg)), Angle.fromDegrees(1.1))
397 self.assertTrue(region.contains(uv1))
398 self.assertTrue(region.contains(uv2))
400 # Read it back, must return the latest one.
401 res = apdb.getDiaObjects(region)
402 self.assert_catalog(res, 1, self.getDiaObjects_table())
404 def test_storeSources(self) -> None:
405 """Store and retrieve DiaSources."""
406 config = self.make_instance()
407 apdb = Apdb.from_config(config)
409 region = self.make_region()
410 visit_time = self.visit_time
412 # have to store Objects first
413 objects = makeObjectCatalog(region, 100)
414 oids = list(objects["diaObjectId"])
415 sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
417 # save the objects and sources
418 apdb.store(visit_time, objects, sources)
420 # read it back, no ID filtering
421 res = apdb.getDiaSources(region, None, visit_time)
422 self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
424 # read it back and filter by ID
425 res = apdb.getDiaSources(region, oids, visit_time)
426 self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
428 # read it back to get schema
429 res = apdb.getDiaSources(region, [], visit_time)
430 self.assert_catalog(res, 0, ApdbTables.DiaSource)
432 # test if a visit is present
433 # data_factory's ccdVisitId generation corresponds to (1, 1)
434 res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
435 self.assertTrue(res)
436 # non-existent image
437 res = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
438 self.assertFalse(res)
440 def test_storeForcedSources(self) -> None:
441 """Store and retrieve DiaForcedSources."""
442 config = self.make_instance()
443 apdb = Apdb.from_config(config)
445 region = self.make_region()
446 visit_time = self.visit_time
448 # have to store Objects first
449 objects = makeObjectCatalog(region, 100)
450 oids = list(objects["diaObjectId"])
451 catalog = makeForcedSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
453 apdb.store(visit_time, objects, forced_sources=catalog)
455 # read it back and check sizes
456 res = apdb.getDiaForcedSources(region, oids, visit_time)
457 self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)
459 # read it back to get schema
460 res = apdb.getDiaForcedSources(region, [], visit_time)
461 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
463 # data_factory's ccdVisitId generation corresponds to (1, 1)
464 res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
465 self.assertTrue(res)
466 # non-existent image
467 res = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
468 self.assertFalse(res)
    def test_null_integer_type(self) -> None:
        """Test that integer column with NULLs correct type on select."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100)
        sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
        # Reset some diaObjectIds to NULL. Note that pandas .loc label slicing
        # is inclusive, so this nulls rows 0 through 10.
        sources.loc[0:10, "diaObjectId"] = None

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
        assert res is not None, "Expecting catalog, not None"
        # A column containing NULLs must come back as nullable Int64 rather
        # than being coerced to float.
        self.assertEqual(res.dtypes["diaObjectId"], pandas.Int64Dtype())
    def test_timestamps(self) -> None:
        """Check that timestamp return type is as expected."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # Cassandra has a millisecond precision, so subtract 1ms to allow for
        # truncated returned values.
        time_before = makeTimestamp(self.processing_time, self.use_mjd, -1)
        objects = makeObjectCatalog(region, 100)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(
            objects, visit_time, processing_time=self.processing_time, use_mjd=self.use_mjd
        )
        time_after = makeTimestamp(self.processing_time, self.use_mjd)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        assert res is not None
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # Column name depends on whether MJD or datetime timestamps are used.
        time_processed_column = makeTimestampColumn("time_processed", self.use_mjd)
        self.assertIn(time_processed_column, res.dtypes)
        dtype = res.dtypes[time_processed_column]
        # MJD timestamps come back as plain floats; datetime timestamps may
        # use any of the pandas datetime64 resolutions.
        timestamp_type_names = (
            ("float64",) if self.use_mjd else ("datetime64[ms]", "datetime64[us]", "datetime64[ns]")
        )
        self.assertIn(dtype.name, timestamp_type_names)
        # Verify that returned time is sensible.
        self.assertTrue(all(time_before <= dt <= time_after for dt in res[time_processed_column]))
528 def test_getDiaObjectsForDedup(self) -> None:
529 """Test getDiaObjectsForDedup() method."""
530 config = self.make_instance()
531 apdb = Apdb.from_config(config)
533 region1 = self.make_region((1.0, 1.0, -1.0))
534 region2 = self.make_region((-1.0, 1.0, -1.0))
535 region3 = self.make_region((-1.0, -1.0, -1.0))
536 nobj = 100
537 objects1 = makeObjectCatalog(region1, nobj)
538 objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2)
539 objects3 = makeObjectCatalog(region3, nobj, start_id=nobj * 4)
541 visits = [
542 (astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai"), objects1),
543 (astropy.time.Time("2021-01-01T00:10:00", format="isot", scale="tai"), objects2),
544 (astropy.time.Time("2021-01-01T00:20:00", format="isot", scale="tai"), objects3),
545 ]
547 for visit_time, objects in visits:
548 apdb.store(visit_time, objects)
550 catalog = apdb.getDiaObjectsForDedup()
551 self.assertEqual(len(catalog), 300)
553 catalog = apdb.getDiaObjectsForDedup(visits[0][0])
554 self.assertEqual(len(catalog), 300)
556 catalog = apdb.getDiaObjectsForDedup(visits[1][0])
557 self.assertEqual(len(catalog), 200)
559 catalog = apdb.getDiaObjectsForDedup(visits[2][0])
560 self.assertEqual(len(catalog), 100)
562 time = astropy.time.Time("2021-01-01T00:30:00", format="isot", scale="tai")
563 catalog = apdb.getDiaObjectsForDedup(time)
564 self.assertEqual(len(catalog), 0)
    def test_getDiaSourcesForDiaObjects(self) -> None:
        """Test getDiaSourcesForDiaObjects() method."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        # Monkey-patch APDB instance to set current time.
        apdb._current_time = lambda: self.processing_time  # type: ignore[method-assign]

        region1 = self.make_region((1.0, 1.0, -1.0))
        region2 = self.make_region((-1.0, 1.0, -1.0))
        region3 = self.make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj)
        objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2)
        objects3 = makeObjectCatalog(region3, nobj, start_id=nobj * 4)

        visits = [
            (astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:10:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:20:00", format="isot", scale="tai"), objects3),
        ]

        # Each visit gets its own block of source IDs so the expected
        # diaSourceId values asserted below are predictable.
        start_id = 1_000_000
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id, use_mjd=self.use_mjd)
            apdb.store(visit_time, objects, sources)
            start_id += 1_000_000

        # Take a small number of objects from different regions.
        object_ids = [
            DiaObjectId.from_named_tuple(next(objects1.itertuples())),
            DiaObjectId.from_named_tuple(next(objects2.itertuples())),
            DiaObjectId.from_named_tuple(next(objects3.itertuples())),
        ]

        # Query from the earliest visit: one source per selected object.
        catalog = apdb.getDiaSourcesForDiaObjects(object_ids, visits[0][0])
        self.assertEqual(len(catalog), 3)
        self.assertEqual(set(catalog["diaObjectId"]), {1, 200, 400})
        self.assertEqual(set(catalog["diaSourceId"]), {1_000_000, 2_000_000, 3_000_000})

        # Query from the last visit: only the source stored at that visit.
        catalog = apdb.getDiaSourcesForDiaObjects(object_ids, visits[2][0])
        self.assertEqual(len(catalog), 1)
        self.assertEqual(set(catalog["diaObjectId"]), {400})
        self.assertEqual(set(catalog["diaSourceId"]), {3_000_000})
    def test_reassignDiaSourcesToDiaObjects(self) -> None:
        """Test reassignDiaSourcesToDiaObjects() method."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        # Monkey-patch APDB instance to set current time.
        apdb._current_time = lambda: self.processing_time  # type: ignore[method-assign]
        apdb_replica = ApdbReplica.from_config(config)

        visit_time = self.visit_time
        lonlat1 = LonLat.fromDegrees(0.0, 0.0)
        lonlat2 = LonLat.fromDegrees(180.0, 0.0)
        # regions around lonlat1/2
        region1 = self.make_region(xyz=(1.0, 0.0, 0.0))
        region2 = self.make_region(xyz=(-1.0, 0.0, 0.0))

        # Store 3 objects and sources at the same position in each region.
        objects = makeObjectCatalog(lonlat1, 3, start_id=100)
        sources = makeSourceCatalog(objects, visit_time, start_id=1000, use_mjd=self.use_mjd)
        apdb.store(visit_time, objects, sources)

        objects = makeObjectCatalog(lonlat2, 3, start_id=200)
        sources = makeSourceCatalog(objects, visit_time, start_id=2000, use_mjd=self.use_mjd)
        apdb.store(visit_time, objects, sources)

        # check that everything as we think it is.
        objects = apdb.getDiaObjects(region1)
        self.assertEqual(set(objects["diaObjectId"]), {100, 101, 102})
        self.assertEqual(list(objects["nDiaSources"]), [1, 1, 1])
        sources = apdb.getDiaSources(region1, [100, 101, 102], visit_time)
        assert sources is not None
        self.assertEqual(set(sources["diaSourceId"]), {1000, 1001, 1002})
        self.assertEqual(set(sources["diaObjectId"]), {100, 101, 102})

        dia_source_ids = [DiaSourceId.from_named_tuple(row) for row in sources.itertuples()]

        # Reassign sources in region1 and increment/decrement nDiaSources.
        reassign = {
            dia_source_id: 100
            for dia_source_id in dia_source_ids
            if dia_source_id.diaSourceId in (1001, 1002)
        }
        apdb.reassignDiaSourcesToDiaObjects(reassign)

        # Object 100 gained two sources; 101 and 102 lost theirs.
        objects = apdb.getDiaObjects(region1)
        self.assertEqual(set(objects["nDiaSources"]), {0, 3})
        sources = apdb.getDiaSources(region1, [100], visit_time)
        assert sources is not None
        self.assertEqual(set(sources["diaSourceId"]), {1000, 1001, 1002})
        self.assertEqual(set(sources["diaObjectId"]), {100})

        sources = apdb.getDiaSources(region2, [201, 202], visit_time)
        assert sources is not None
        self.assertEqual(set(sources["diaSourceId"]), {2001, 2002})
        dia_source_ids = [DiaSourceId.from_named_tuple(row) for row in sources.itertuples()]

        # Reassign but do not increment/decrement nDiaSources.
        reassign = {
            dia_source_id: 200
            for dia_source_id in dia_source_ids
            if dia_source_id.diaSourceId in (2001, 2002)
        }
        apdb.reassignDiaSourcesToDiaObjects(
            reassign, increment_nDiaSources=False, decrement_nDiaSources=False
        )

        # Counts are unchanged, but the sources now point at object 200.
        objects = apdb.getDiaObjects(region2)
        self.assertEqual(set(objects["nDiaSources"]), {1})
        sources = apdb.getDiaSources(region2, [200], visit_time)
        assert sources is not None
        self.assertEqual(set(sources["diaSourceId"]), {2000, 2001, 2002})
        self.assertEqual(set(sources["diaObjectId"]), {200})

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)
        else:
            assert replica_chunks is not None

            # There could be one or two chunks.
            self.assertTrue(1 <= len(replica_chunks) <= 2)

            update_records = apdb_replica.getUpdateRecordChunks([chunk.id for chunk in replica_chunks])
            # Two reassignments for region1, three increments/decrements for
            # that region, plus two reassignments for region2 without
            # increments/decrements.
            self.assertEqual(len(update_records), 2 + 3 + 2)
    def test_setValidityEnd(self) -> None:
        """Store DiaObjects and truncate validity for some."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        # Monkey-patch APDB instance to set current time.
        apdb._current_time = lambda: self.processing_time  # type: ignore[method-assign]
        apdb_replica = ApdbReplica.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 100, self.getDiaObjects_table())

        # Select first 10 objects.
        object_ids = [DiaObjectId.from_named_tuple(row) for row in catalog.iloc[:10].itertuples()]
        count = apdb.setValidityEnd(object_ids, self.processing_time)
        self.assertEqual(count, 10)

        # Objects whose validity was ended are no longer returned.
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 90, self.getDiaObjects_table())

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)
        else:
            # Check that there are 10 update records in replica tables.
            assert replica_chunks is not None

            # There could be one or two chunks.
            self.assertTrue(1 <= len(replica_chunks) <= 2)

            update_records = apdb_replica.getUpdateRecordChunks([chunk.id for chunk in replica_chunks])
            self.assertEqual(len(update_records), 10)

        # Calling again with the same already-ended objects updates nothing.
        count = apdb.setValidityEnd(object_ids, self.processing_time)
        self.assertEqual(count, 0)

        # Try with non-existing object.
        object_ids = [DiaObjectId.from_named_tuple(row) for row in catalog.iloc[10:12].itertuples()]
        object_ids += [DiaObjectId(diaObjectId=1_000_000, ra=0.0, dec=0.0)]
        with self.assertRaises(LookupError):
            apdb.setValidityEnd(object_ids, self.processing_time, raise_on_missing_id=True)

        # Without raise_on_missing_id the two existing objects are updated and
        # the missing one is silently skipped.
        count = apdb.setValidityEnd(object_ids, self.processing_time)
        self.assertEqual(count, 2)
    def test_resetDedup(self) -> None:
        """Test resetDedup method."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()

        # make catalog with Objects
        objects = makeObjectCatalog(region, 100)

        visit_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        dedup_time1 = astropy.time.Time("2021-01-01T12:00:00", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-01-02T00:00:00", format="isot", scale="tai")
        dedup_time2 = astropy.time.Time("2021-01-02T12:00:00", format="isot", scale="tai")

        # store catalog
        apdb.store(visit_time1, objects)

        catalog = apdb.getDiaObjectsForDedup()
        self.assertEqual(len(catalog), 100)

        catalog = apdb.getDiaObjectsForDedup(visit_time1)
        self.assertEqual(len(catalog), 100)

        # After a reset the row count is backend-dependent, see
        # _count_after_reset_dedup().
        apdb.resetDedup(dedup_time1)

        catalog = apdb.getDiaObjectsForDedup(visit_time1)
        self.assertEqual(len(catalog), self._count_after_reset_dedup(100))

        # Storing the same objects again makes them visible for dedup again.
        apdb.store(visit_time2, objects)

        catalog = apdb.getDiaObjectsForDedup()
        self.assertEqual(len(catalog), 100)

        catalog = apdb.getDiaObjectsForDedup(dedup_time1)
        self.assertEqual(len(catalog), 100)

        apdb.resetDedup(dedup_time2)

        catalog = apdb.getDiaObjectsForDedup(dedup_time1)
        self.assertEqual(len(catalog), self._count_after_reset_dedup(100))

        catalog = apdb.getDiaObjectsForDedup()
        self.assertEqual(len(catalog), 0)
    def _count_after_reset_dedup(self, count_before: int) -> int:
        """Return the number of rows that will be returned by
        getDiaObjectsForDedup() after resetDedup() was called. For SQL backend
        deduplication data comes from a regular table, and it is not removed
        by resetDedup().

        Parameters
        ----------
        count_before : `int`
            Row count before resetDedup() was called.
        """
        raise NotImplementedError()
    def test_getChunks(self) -> None:
        """Store and retrieve replica chunks."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)
        visit_time = self.visit_time

        region1 = self.make_region((1.0, 1.0, -1.0))
        region2 = self.make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj)
        objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2)

        # With the default 10 minutes replica chunk window we should have 4
        # records.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        # Store sources and forced sources for each visit; IDs are offset per
        # visit so every store adds ``nobj`` new rows.
        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id, use_mjd=self.use_mjd)
            fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id, use_mjd=self.use_mjd)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
                apdb_replica.getTableDataChunks(ApdbTables.DiaObject, [])

        else:
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 4)

            # SSObject has no replica chunk data.
            with self.assertRaisesRegex(ValueError, "does not support replica chunks"):
                apdb_replica.getTableDataChunks(ApdbTables.SSObject, [])

            def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
                # Check row counts and column types of the chunk data for the
                # DiaObject, DiaSource, and DiaForcedSource tables.
                if n_records is None:
                    n_records = len(replica_chunks) * nobj
                res = apdb_replica.getTableDataChunks(
                    ApdbTables.DiaObject, (chunk.id for chunk in replica_chunks)
                )
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                # Validity-start column name and type depend on use_mjd.
                validityStartColumn = "validityStartMjdTai" if self.use_mjd else "validityStart"
                validityStartType = (
                    felis.datamodel.DataType.double if self.use_mjd else felis.datamodel.DataType.timestamp
                )
                self.assert_column_types(
                    res,
                    {
                        "apdb_replica_chunk": felis.datamodel.DataType.long,
                        "diaObjectId": felis.datamodel.DataType.long,
                        validityStartColumn: validityStartType,
                        "ra": felis.datamodel.DataType.double,
                        "dec": felis.datamodel.DataType.double,
                        "parallax": felis.datamodel.DataType.float,
                        "nDiaSources": felis.datamodel.DataType.int,
                    },
                )

                res = apdb_replica.getTableDataChunks(
                    ApdbTables.DiaSource, (chunk.id for chunk in replica_chunks)
                )
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                self.assert_column_types(
                    res,
                    {
                        "apdb_replica_chunk": felis.datamodel.DataType.long,
                        "diaSourceId": felis.datamodel.DataType.long,
                        "visit": felis.datamodel.DataType.long,
                        "detector": felis.datamodel.DataType.short,
                    },
                )

                res = apdb_replica.getTableDataChunks(
                    ApdbTables.DiaForcedSource, (chunk.id for chunk in replica_chunks)
                )
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)
                self.assert_column_types(
                    res,
                    {
                        "apdb_replica_chunk": felis.datamodel.DataType.long,
                        "diaObjectId": felis.datamodel.DataType.long,
                        "visit": felis.datamodel.DataType.long,
                        "detector": felis.datamodel.DataType.short,
                    },
                )

            # read it back and check sizes
            _check_chunks(replica_chunks, 800)
            _check_chunks(replica_chunks[1:], 600)
            _check_chunks(replica_chunks[1:-1], 400)
            _check_chunks(replica_chunks[2:3], 200)
            _check_chunks([])

            # try to remove some of those
            deleted_chunks = replica_chunks[:1]
            apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

            # All queries on deleted ids should return empty set.
            _check_chunks(deleted_chunks, 0)

            replica_chunks = apdb_replica.getReplicaChunks()
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 3)
923 _check_chunks(replica_chunks, 600)
925 def test_reassignObjects(self) -> None:
926 """Reassign DiaObjects."""
927 # don't care about sources.
928 config = self.make_instance()
929 apdb = Apdb.from_config(config)
931 region = self.make_region()
932 visit_time = self.visit_time
933 objects = makeObjectCatalog(region, 100)
934 oids = list(objects["diaObjectId"])
935 sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
936 apdb.store(visit_time, objects, sources)
938 # read it back and filter by ID
939 res = apdb.getDiaSources(region, oids, visit_time)
940 self.assert_catalog(res, len(sources), ApdbTables.DiaSource)
942 apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
943 res = apdb.getDiaSources(region, oids, visit_time)
944 self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)
946 with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
947 apdb.reassignDiaSources(
948 {
949 1000: 1,
950 7: 3,
951 }
952 )
953 self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)
955 def test_storeUpdateRecord(self) -> None:
956 """Test _storeUpdateRecord() method."""
957 config = self.make_instance()
958 apdb = Apdb.from_config(config)
960 # Times are totally arbitrary.
961 update_time_ns1 = 2_000_000_000_000_000_000
962 update_time_ns2 = 2_000_000_001_000_000_000
963 records = [
964 ApdbReassignDiaSourceToSSObjectRecord(
965 update_time_ns=update_time_ns1,
966 update_order=0,
967 diaSourceId=1,
968 ssObjectId=1,
969 ssObjectReassocTimeMjdTai=60000.0,
970 ra=45.0,
971 dec=-45.0,
972 midpointMjdTai=60000.0,
973 ),
974 ApdbWithdrawDiaSourceRecord(
975 update_time_ns=update_time_ns1,
976 update_order=1,
977 diaSourceId=123456,
978 timeWithdrawnMjdTai=61000.0,
979 ra=45.0,
980 dec=-45.0,
981 midpointMjdTai=60000.0,
982 ),
983 ApdbReassignDiaSourceToSSObjectRecord(
984 update_time_ns=update_time_ns1,
985 update_order=3,
986 diaSourceId=2,
987 ssObjectId=3,
988 ssObjectReassocTimeMjdTai=60000.0,
989 ra=45.0,
990 dec=-45.0,
991 midpointMjdTai=60000.0,
992 ),
993 ApdbWithdrawDiaSourceRecord(
994 update_time_ns=update_time_ns2,
995 update_order=0,
996 diaSourceId=123456,
997 timeWithdrawnMjdTai=61000.0,
998 ra=45.0,
999 dec=-45.0,
1000 midpointMjdTai=60000.0,
1001 ),
1002 ]
1004 update_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
1005 chunk = ReplicaChunk.make_replica_chunk(update_time, 600)
1007 if not self.enable_replica:
1008 with self.assertRaises(TypeError):
1009 self.store_update_records(apdb, records, chunk)
1010 else:
1011 self.store_update_records(apdb, records, chunk)
1013 apdb_replica = ApdbReplica.from_config(config)
1014 records_returned = apdb_replica.getUpdateRecordChunks([chunk.id])
1016 # Input records are ordered, output will be ordered too.
1017 self.assertEqual(records_returned, records)
    @abstractmethod
    def store_update_records(self, apdb: Apdb, records: list[ApdbUpdateRecord], chunk: ReplicaChunk) -> None:
        """Store update records in database, must be overriden in subclass.

        Parameters
        ----------
        apdb : `Apdb`
            APDB instance to store records into.
        records : `list` [`ApdbUpdateRecord`]
            Update records to store.
        chunk : `ReplicaChunk`
            Replica chunk that the records belong to.

        Notes
        -----
        This exists as an abstract hook because storing update records is an
        implementation-specific (non-public) operation.
        """
        raise NotImplementedError()
1024 def test_midpointMjdTai_src(self) -> None:
1025 """Test for time filtering of DiaSources."""
1026 config = self.make_instance()
1027 apdb = Apdb.from_config(config)
1029 region = self.make_region()
1030 # 2021-01-01 plus 360 days is 2021-12-27
1031 src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
1032 src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
1033 visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
1034 visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
1035 visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
1036 one_sec = astropy.time.TimeDelta(1.0, format="sec")
1038 objects = makeObjectCatalog(region, 100)
1039 oids = list(objects["diaObjectId"])
1040 sources = makeSourceCatalog(objects, src_time1, 0, use_mjd=self.use_mjd)
1041 apdb.store(src_time1, objects, sources)
1043 sources = makeSourceCatalog(objects, src_time2, 100, use_mjd=self.use_mjd)
1044 apdb.store(src_time2, objects, sources)
1046 # reading at time of last save should read all
1047 res = apdb.getDiaSources(region, oids, src_time2)
1048 self.assert_catalog(res, 200, ApdbTables.DiaSource)
1050 # one second before 12 months
1051 res = apdb.getDiaSources(region, oids, visit_time0)
1052 self.assert_catalog(res, 200, ApdbTables.DiaSource)
1054 # reading at later time of last save should only read a subset
1055 res = apdb.getDiaSources(region, oids, visit_time1)
1056 self.assert_catalog(res, 100, ApdbTables.DiaSource)
1058 # reading at later time of last save should only read a subset
1059 res = apdb.getDiaSources(region, oids, visit_time2)
1060 self.assert_catalog(res, 0, ApdbTables.DiaSource)
1062 # Use explicit start time argument instead of 12 month window, visit
1063 # time does not matter in this case, set it to before all data.
1064 res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time1 - one_sec)
1065 self.assert_catalog(res, 200, ApdbTables.DiaSource)
1067 res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time2 - one_sec)
1068 self.assert_catalog(res, 100, ApdbTables.DiaSource)
1070 res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time2 + one_sec)
1071 self.assert_catalog(res, 0, ApdbTables.DiaSource)
1073 def test_midpointMjdTai_fsrc(self) -> None:
1074 """Test for time filtering of DiaForcedSources."""
1075 config = self.make_instance()
1076 apdb = Apdb.from_config(config)
1078 region = self.make_region()
1079 src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
1080 src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
1081 visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
1082 visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
1083 visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
1084 one_sec = astropy.time.TimeDelta(1.0, format="sec")
1086 objects = makeObjectCatalog(region, 100)
1087 oids = list(objects["diaObjectId"])
1088 sources = makeForcedSourceCatalog(objects, src_time1, 1, use_mjd=self.use_mjd)
1089 apdb.store(src_time1, objects, forced_sources=sources)
1091 sources = makeForcedSourceCatalog(objects, src_time2, 2, use_mjd=self.use_mjd)
1092 apdb.store(src_time2, objects, forced_sources=sources)
1094 # reading at time of last save should read all
1095 res = apdb.getDiaForcedSources(region, oids, src_time2)
1096 self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)
1098 # one second before 12 months
1099 res = apdb.getDiaForcedSources(region, oids, visit_time0)
1100 self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)
1102 # reading at later time of last save should only read a subset
1103 res = apdb.getDiaForcedSources(region, oids, visit_time1)
1104 self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)
1106 # reading at later time of last save should only read a subset
1107 res = apdb.getDiaForcedSources(region, oids, visit_time2)
1108 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
1110 # Use explicit start time argument instead of 12 month window, visit
1111 # time does not matter in this case, set it to before all data.
1112 res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time1 - one_sec)
1113 self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)
1115 res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time2 - one_sec)
1116 self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)
1118 res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time2 + one_sec)
1119 self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)
1121 def test_metadata(self) -> None:
1122 """Simple test for writing/reading metadata table"""
1123 config = self.make_instance()
1124 apdb = Apdb.from_config(config)
1125 metadata = apdb.metadata
1127 # APDB should write two or three metadata items with version numbers
1128 # and a frozen JSON config.
1129 self.assertFalse(metadata.empty())
1130 self.assertEqual(len(list(metadata.items())), self.meta_row_count)
1132 metadata.set("meta", "data")
1133 metadata.set("data", "meta")
1135 self.assertFalse(metadata.empty())
1136 self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})
1138 with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
1139 metadata.set("meta", "data1")
1141 metadata.set("meta", "data2", force=True)
1142 self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})
1144 self.assertTrue(metadata.delete("meta"))
1145 self.assertIsNone(metadata.get("meta"))
1146 self.assertFalse(metadata.delete("meta"))
1148 self.assertEqual(metadata.get("data"), "meta")
1149 self.assertEqual(metadata.get("meta", "meta"), "meta")
1151 def test_schemaVersionFromYaml(self) -> None:
1152 """Check version number handling for reading schema from YAML."""
1153 config = self.make_instance()
1154 default_schema = config.schema_file
1155 apdb = Apdb.from_config(config)
1156 self.assertEqual(apdb.schema.schemaVersion(), VersionTuple(0, 1, 1))
1158 with update_schema_yaml(default_schema, version="") as schema_file:
1159 config = self.make_instance(schema_file=schema_file)
1160 apdb = Apdb.from_config(config)
1161 self.assertEqual(
1162 apdb.schema.schemaVersion(),
1163 VersionTuple(0, 1, 0),
1164 )
1166 with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
1167 config = self.make_instance(schema_file=schema_file)
1168 apdb = Apdb.from_config(config)
1169 self.assertEqual(
1170 apdb.schema.schemaVersion(),
1171 VersionTuple(99, 0, 0),
1172 )
1174 def test_config_freeze(self) -> None:
1175 """Test that some config fields are correctly frozen in database."""
1176 config = self.make_instance()
1178 # `enable_replica` is the only parameter that is frozen in all
1179 # implementations.
1180 config.enable_replica = not self.enable_replica
1181 apdb = Apdb.from_config(config)
1182 frozen_config = apdb.getConfig()
1183 self.assertEqual(frozen_config.enable_replica, self.enable_replica)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    # Fixed visit time shared by the tests in this class.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Parameters
        ----------
        **kwargs
            Configuration overrides forwarded to the concrete config class.

        Returns
        -------
        config : `ApdbConfig`
            Configuration instance for the tested implementation.

        Notes
        -----
        This method should return configuration that points to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()
1201 def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
1202 """Make a region to use in tests"""
1203 return _make_region(xyz)
1205 def test_schema_add_replica(self) -> None:
1206 """Check that new code can work with old schema without replica
1207 tables.
1208 """
1209 # Make schema without replica tables.
1210 config = self.make_instance(enable_replica=False)
1211 apdb = Apdb.from_config(config)
1212 apdb_replica = ApdbReplica.from_config(config)
1214 # Make APDB instance configured for replication.
1215 config.enable_replica = True
1216 apdb = Apdb.from_config(config)
1218 # Try to insert something, should work OK.
1219 region = self.make_region()
1220 visit_time = self.visit_time
1222 # have to store Objects first
1223 objects = makeObjectCatalog(region, 100)
1224 sources = makeSourceCatalog(objects, visit_time)
1225 fsources = makeForcedSourceCatalog(objects, visit_time)
1226 apdb.store(visit_time, objects, sources, fsources)
1228 # There should be no replica chunks.
1229 replica_chunks = apdb_replica.getReplicaChunks()
1230 self.assertIsNone(replica_chunks)
    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        # Baseline: current schema reports its expected version.
        self.assertEqual(apdb.schema.schemaVersion(), VersionTuple(0, 1, 1))

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)
                # Version is checked only when we try to do connect.
                apdb.metadata.items()