Coverage for python/lsst/dax/apdb/tests/_apdb.py: 15%
441 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 08:49 +0000
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]
26import contextlib
27import os
28import tempfile
29from abc import ABC, abstractmethod
30from collections.abc import Iterator
31from tempfile import TemporaryDirectory
32from typing import TYPE_CHECKING, Any
34import astropy.time
35import felis.datamodel
36import pandas
37import yaml
39from lsst.sphgeom import Angle, Circle, LonLat, Region, UnitVector3d
41from .. import (
42 Apdb,
43 ApdbConfig,
44 ApdbReassignDiaSourceRecord,
45 ApdbReplica,
46 ApdbTableData,
47 ApdbTables,
48 ApdbUpdateRecord,
49 ApdbWithdrawDiaSourceRecord,
50 IncompatibleVersionError,
51 ReplicaChunk,
52 VersionTuple,
53)
54from .data_factory import (
55 makeForcedSourceCatalog,
56 makeObjectCatalog,
57 makeSourceCatalog,
58 makeTimestampNow,
59)
60from .utils import TestCaseMixin
62if TYPE_CHECKING:
63 from ..pixelization import Pixelization
def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular region centered on ``xyz`` for use in tests."""
    field_of_view = 0.0013  # full width of the circle, radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(field_of_view / 2))
@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    path : `str`
        Path for the updated schema file. The file is written to a temporary
        directory which is removed when the context exits.
    """
    # safe_load_all is the documented shorthand for load_all with SafeLoader
    # and cannot accidentally pick up an unsafe loader.
    with open(schema_file) as yaml_stream:
        schemas_list = list(yaml.safe_load_all(yaml_stream))

    # Edit YAML contents in memory.
    for schema in schemas_list:
        # Optionally drop metadata table.
        if drop_metadata:
            schema["tables"] = [table for table in schema["tables"] if table["name"] != "metadata"]
        if version is not None:
            if version == "":
                # Empty string means "remove version from the schema".
                del schema["version"]
            else:
                schema["version"] = version

    # Write the updated schema to a temporary location that lives only as
    # long as the caller stays inside the context.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as yaml_stream:
            yaml.dump_all(schemas_list, stream=yaml_stream)
        yield output_path
class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    # Fixed visit time shared by most tests below.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    use_mjd: bool = True
    """If True then timestamp columns are MJD TAI."""

    extra_chunk_columns = 1
    """Number of additional columns in chunk tables."""

    meta_row_count = 3
    """Initial row count in metadata table."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 7,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 12,
        ApdbTables.DiaForcedSource: 8,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    @abstractmethod
    def pixelization(self, config: ApdbConfig) -> Pixelization:
        """Return pixelization used by implementation."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type; column count is taken from ``table_column_count``
            plus ``extra_chunk_columns``.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for replica chunk id
        self.assertEqual(
            len(catalog.column_names()), self.table_column_count[table] + self.extra_chunk_columns
        )

    def assert_column_types(self, catalog: Any, types: dict[str, felis.datamodel.DataType]) -> None:
        """Check that the given columns of ``catalog`` have the expected
        felis data types. Columns not listed in ``types`` are not checked.
        """
        column_defs = dict(catalog.column_defs())
        for column, datatype in types.items():
            self.assertEqual(column_defs[column], datatype)

    def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
        """Make a region to use in tests"""
        return _make_region(xyz)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.SSObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.SSSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject_To_Object_Match))

        # Test from_uri factory method with the same config.
        with tempfile.NamedTemporaryFile() as tmpfile:
            config.save(tmpfile.name)
            apdb = Apdb.from_uri(tmpfile.name)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.SSObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.SSSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject_To_Object_Match))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
        self.assertFalse(res)

        # get sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # Database is empty, no images exist.
        res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
        self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        region = self.make_region()
        visit_time = self.visit_time
        # make catalog with no Objects
        catalog = makeObjectCatalog(region, 0, visit_time)

        with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeMovingObject(self) -> None:
        """Store and retrieve DiaObject which changes its position."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        pixelization = self.pixelization(config)

        lon_deg, lat_deg = 0.0, 0.0
        lonlat1 = LonLat.fromDegrees(lon_deg - 1.0, lat_deg)
        lonlat2 = LonLat.fromDegrees(lon_deg + 1.0, lat_deg)
        uv1 = UnitVector3d(lonlat1)
        uv2 = UnitVector3d(lonlat2)

        # Check that they fall into different pixels.
        self.assertNotEqual(pixelization.pixel(uv1), pixelization.pixel(uv2))

        # Store one object at two different positions.
        visit_time1 = self.visit_time
        catalog1 = makeObjectCatalog(lonlat1, 1, visit_time1)
        apdb.store(visit_time1, catalog1)

        visit_time2 = visit_time1 + astropy.time.TimeDelta(120.0, format="sec")
        catalog1 = makeObjectCatalog(lonlat2, 1, visit_time2)
        apdb.store(visit_time2, catalog1)

        # Make region covering both points.
        region = Circle(UnitVector3d(LonLat.fromDegrees(lon_deg, lat_deg)), Angle.fromDegrees(1.1))
        self.assertTrue(region.contains(uv1))
        self.assertTrue(region.contains(uv2))

        # Read it back, must return the latest one.
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 1, self.getDiaObjects_table())

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (1, 1)
        res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
        self.assertTrue(res)
        # non-existent image
        res = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
        self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
        self.assertTrue(res)
        # non-existent image
        res = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
        self.assertFalse(res)

    def test_timestamps(self) -> None:
        """Check that timestamp return type is as expected."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time

        # Cassandra has a millisecond precision, so subtract 1ms to allow for
        # truncated returned values.
        time_before = makeTimestampNow(self.use_mjd, -1)
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
        time_after = makeTimestampNow(self.use_mjd)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        assert res is not None
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # Column name and pandas dtype depend on whether MJD TAI or native
        # timestamps are configured.
        time_processed_column = "timeProcessedMjdTai" if self.use_mjd else "time_processed"
        self.assertIn(time_processed_column, res.dtypes)
        dtype = res.dtypes[time_processed_column]
        timestamp_type_name = "float64" if self.use_mjd else "datetime64[ns]"
        self.assertEqual(dtype.name, timestamp_type_name)
        # Verify that returned time is sensible.
        self.assertTrue(all(time_before <= dt <= time_after for dt in res[time_processed_column]))

    def test_getChunks(self) -> None:
        """Store and retrieve replica chunks."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)
        visit_time = self.visit_time

        region1 = self.make_region((1.0, 1.0, -1.0))
        region2 = self.make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # With the default 10 minutes replica chunk window we should have 4
        # records.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id, use_mjd=self.use_mjd)
            fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id, use_mjd=self.use_mjd)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
                apdb_replica.getTableDataChunks(ApdbTables.DiaObject, [])

        else:
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 4)

            with self.assertRaisesRegex(ValueError, "does not support replica chunks"):
                apdb_replica.getTableDataChunks(ApdbTables.SSObject, [])

            def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
                # Verify row counts and column types for the three replicated
                # tables over the given set of chunks.
                if n_records is None:
                    n_records = len(replica_chunks) * nobj
                res = apdb_replica.getTableDataChunks(
                    ApdbTables.DiaObject, (chunk.id for chunk in replica_chunks)
                )
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                validityStartColumn = "validityStartMjdTai" if self.use_mjd else "validityStart"
                validityStartType = (
                    felis.datamodel.DataType.double if self.use_mjd else felis.datamodel.DataType.timestamp
                )
                self.assert_column_types(
                    res,
                    {
                        "apdb_replica_chunk": felis.datamodel.DataType.long,
                        "diaObjectId": felis.datamodel.DataType.long,
                        validityStartColumn: validityStartType,
                        "ra": felis.datamodel.DataType.double,
                        "dec": felis.datamodel.DataType.double,
                        "parallax": felis.datamodel.DataType.float,
                        "nDiaSources": felis.datamodel.DataType.int,
                    },
                )

                res = apdb_replica.getTableDataChunks(
                    ApdbTables.DiaSource, (chunk.id for chunk in replica_chunks)
                )
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                self.assert_column_types(
                    res,
                    {
                        "apdb_replica_chunk": felis.datamodel.DataType.long,
                        "diaSourceId": felis.datamodel.DataType.long,
                        "visit": felis.datamodel.DataType.long,
                        "detector": felis.datamodel.DataType.short,
                    },
                )

                res = apdb_replica.getTableDataChunks(
                    ApdbTables.DiaForcedSource, (chunk.id for chunk in replica_chunks)
                )
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)
                self.assert_column_types(
                    res,
                    {
                        "apdb_replica_chunk": felis.datamodel.DataType.long,
                        "diaObjectId": felis.datamodel.DataType.long,
                        "visit": felis.datamodel.DataType.long,
                        "detector": felis.datamodel.DataType.short,
                    },
                )

            # read it back and check sizes
            _check_chunks(replica_chunks, 800)
            _check_chunks(replica_chunks[1:], 600)
            _check_chunks(replica_chunks[1:-1], 400)
            _check_chunks(replica_chunks[2:3], 200)
            _check_chunks([])

            # try to remove some of those
            deleted_chunks = replica_chunks[:1]
            apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

            # All queries on deleted ids should return empty set.
            _check_chunks(deleted_chunks, 0)

            replica_chunks = apdb_replica.getReplicaChunks()
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 3)

            _check_chunks(replica_chunks, 600)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
        apdb.store(visit_time, objects, sources)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # Reassigned sources no longer show up in getDiaSources results.
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_storeUpdateRecord(self) -> None:
        """Test _storeUpdateRecord() method."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        # Times are totally arbitrary.
        update_time_ns1 = 2_000_000_000_000_000_000
        update_time_ns2 = 2_000_000_001_000_000_000
        records = [
            ApdbReassignDiaSourceRecord(
                update_time_ns=update_time_ns1,
                update_order=0,
                diaSourceId=1,
                diaObjectId=321,
                ssObjectId=1,
                ssObjectReassocTimeMjdTai=60000.0,
                ra=45.0,
                dec=-45.0,
            ),
            ApdbWithdrawDiaSourceRecord(
                update_time_ns=update_time_ns1,
                update_order=1,
                diaSourceId=123456,
                diaObjectId=321,
                timeWithdrawnMjdTai=61000.0,
                ra=45.0,
                dec=-45.0,
            ),
            ApdbReassignDiaSourceRecord(
                update_time_ns=update_time_ns1,
                update_order=3,
                diaSourceId=2,
                diaObjectId=3,
                ssObjectId=3,
                ssObjectReassocTimeMjdTai=60000.0,
                ra=45.0,
                dec=-45.0,
            ),
            ApdbWithdrawDiaSourceRecord(
                update_time_ns=update_time_ns2,
                update_order=0,
                diaSourceId=123456,
                diaObjectId=321,
                timeWithdrawnMjdTai=61000.0,
                ra=45.0,
                dec=-45.0,
            ),
        ]

        update_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        chunk = ReplicaChunk.make_replica_chunk(update_time, 600)

        if not self.enable_replica:
            with self.assertRaises(TypeError):
                self.store_update_records(apdb, records, chunk)
        else:
            self.store_update_records(apdb, records, chunk)

            apdb_replica = ApdbReplica.from_config(config)
            records_returned = apdb_replica.getUpdateRecordChunks([chunk.id])

            # Input records are ordered, output will be ordered too.
            self.assertEqual(records_returned, records)

    @abstractmethod
    def store_update_records(self, apdb: Apdb, records: list[ApdbUpdateRecord], chunk: ReplicaChunk) -> None:
        """Store update records in database, must be overriden in subclass."""
        raise NotImplementedError()

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
        one_sec = astropy.time.TimeDelta(1.0, format="sec")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0, use_mjd=self.use_mjd)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100, use_mjd=self.use_mjd)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # Use explicit start time argument instead of 12 month window, visit
        # time does not matter in this case, set it to before all data.
        res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time1 - one_sec)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time2 - one_sec)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time2 + one_sec)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = self.make_region()
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
        one_sec = astropy.time.TimeDelta(1.0, format="sec")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1, use_mjd=self.use_mjd)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2, use_mjd=self.use_mjd)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # Use explicit start time argument instead of 12 month window, visit
        # time does not matter in this case, set it to before all data.
        res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time1 - one_sec)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time2 - one_sec)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time2 + one_sec)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata

        # APDB should write two or three metadata items with version numbers
        # and a frozen JSON config.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), self.meta_row_count)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        # Setting an existing key must fail unless force=True is given.
        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_instance()
        default_schema = config.schema_file
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        # Missing version in YAML defaults to 0.1.0.
        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(
                apdb._schema.schemaVersion(),  # type: ignore[attr-defined]
                VersionTuple(0, 1, 0),
            )

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(
                apdb._schema.schemaVersion(),  # type: ignore[attr-defined]
                VersionTuple(99, 0, 0),
            )

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_instance()

        # `enable_replica` is the only parameter that is frozen in all
        # implementations.
        config.enable_replica = not self.enable_replica
        apdb = Apdb.from_config(config)
        frozen_config = apdb.getConfig()
        self.assertEqual(frozen_config.enable_replica, self.enable_replica)
class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    # Fixed visit time shared by the tests below.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
        """Make a region to use in tests"""
        return _make_region(xyz)

    def test_schema_add_replica(self) -> None:
        """Check that new code can work with old schema without replica
        tables.
        """
        # Make schema without replica tables.
        config = self.make_instance(enable_replica=False)
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)

        # Make APDB instance configured for replication.
        config.enable_replica = True
        apdb = Apdb.from_config(config)

        # Try to insert something, should work OK.
        region = self.make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        sources = makeSourceCatalog(objects, visit_time)
        fsources = makeForcedSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources, fsources)

        # There should be no replica chunks.
        replica_chunks = apdb_replica.getReplicaChunks()
        self.assertIsNone(replica_chunks)

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)
                # Version is checked only when we try to do connect.
                apdb.metadata.items()