Coverage for tests/test_obscore.py: 17%
307 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-05 01:26 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import gc
23import os
24import tempfile
25import unittest
26import warnings
27from abc import abstractmethod
28from typing import cast
30import astropy.time
31import sqlalchemy
32from lsst.daf.butler import (
33 CollectionType,
34 Config,
35 DataCoordinate,
36 DatasetRef,
37 DatasetType,
38 StorageClassFactory,
39)
40from lsst.daf.butler.registries.sql import SqlRegistry
41from lsst.daf.butler.registry import Registry, RegistryConfig, _ButlerRegistry, _RegistryFactory
42from lsst.daf.butler.registry.obscore import (
43 DatasetTypeConfig,
44 ObsCoreConfig,
45 ObsCoreLiveTableManager,
46 ObsCoreSchema,
47)
48from lsst.daf.butler.registry.obscore._schema import _STATIC_COLUMNS
49from lsst.daf.butler.tests.utils import TestCaseMixin, makeTestTempDir, removeTestTempDir
50from lsst.sphgeom import Box, ConvexPolygon, LonLat, UnitVector3d
52try:
53 import testing.postgresql # type: ignore
54except ImportError:
55 testing = None
57TESTDIR = os.path.abspath(os.path.dirname(__file__))
class ObsCoreTests(TestCaseMixin):
    """Base class for testing obscore manager functionality.

    Concrete subclasses provide a backend-specific ``make_registry_config``
    (SQLite or PostgreSQL) and a ``root`` directory; everything else —
    dimension records, dataset types, collections, and the actual test
    methods — is shared here.
    """

    # Filesystem root of the test repository; set by subclass setUp().
    root: str

    def make_registry(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> _ButlerRegistry:
        """Create new empty Registry.

        Parameters
        ----------
        collections : `list` [ `str` ], optional
            Collection names/patterns forwarded to the obscore configuration.
        collection_type : `str`, optional
            Collection type name forwarded to the obscore configuration.
        """
        config = self.make_registry_config(collections, collection_type)
        registry = _RegistryFactory(config).create_from_config(butlerRoot=self.root)
        self.initialize_registry(registry)
        return registry

    @abstractmethod
    def make_registry_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> RegistryConfig:
        """Make Registry configuration.

        Implemented by backend-specific subclasses.
        """
        raise NotImplementedError()

    def initialize_registry(self, registry: Registry) -> None:
        """Populate Registry with the things that we need for tests.

        Creates one instrument, one filter, four detectors, four exposures,
        five visits (visit 9 deliberately gets a non-polygon region), four
        dataset types, six run collections, chained collections, and one
        tagged collection.
        """
        registry.insertDimensionData("instrument", {"name": "DummyCam"})
        registry.insertDimensionData(
            "physical_filter", {"instrument": "DummyCam", "name": "d-r", "band": "r"}
        )
        for detector in (1, 2, 3, 4):
            registry.insertDimensionData(
                "detector", {"instrument": "DummyCam", "id": detector, "full_name": f"detector{detector}"}
            )

        for exposure in (1, 2, 3, 4):
            registry.insertDimensionData(
                "exposure",
                {
                    "instrument": "DummyCam",
                    "id": exposure,
                    "obs_id": f"exposure{exposure}",
                    "physical_filter": "d-r",
                },
            )

        registry.insertDimensionData("visit_system", {"instrument": "DummyCam", "id": 1, "name": "default"})

        for visit in (1, 2, 3, 4, 9):
            # 45-second visits, one per minute starting at 08:0<visit>:00 TAI.
            visit_start = astropy.time.Time(f"2020-01-01 08:0{visit}:00", scale="tai")
            visit_end = astropy.time.Time(f"2020-01-01 08:0{visit}:45", scale="tai")
            registry.insertDimensionData(
                "visit",
                {
                    "instrument": "DummyCam",
                    "id": visit,
                    "name": f"visit{visit}",
                    "physical_filter": "d-r",
                    "visit_system": 1,
                    "datetime_begin": visit_start,
                    "datetime_end": visit_end,
                },
            )

        # Only couple of exposures are linked to visits.
        for visit in (1, 2):
            registry.insertDimensionData(
                "visit_definition",
                {
                    "instrument": "DummyCam",
                    "exposure": visit,
                    "visit": visit,
                },
            )

        # Map (visit, detector) to region; saved in self.regions so tests
        # (e.g. test_update_exposure_region) can reuse the exact polygons.
        self.regions: dict[tuple[int, int], ConvexPolygon] = {}
        for visit in (1, 2, 3, 4):
            for detector in (1, 2, 3, 4):
                # 2x2-degree boxes centered on a per-visit longitude and
                # per-detector latitude, so regions never overlap.
                lon = visit * 90 - 88
                lat = detector * 2 - 5
                region = ConvexPolygon(
                    [
                        UnitVector3d(LonLat.fromDegrees(lon - 1.0, lat - 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon + 1.0, lat - 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon + 1.0, lat + 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon - 1.0, lat + 1.0)),
                    ]
                )
                registry.insertDimensionData(
                    "visit_detector_region",
                    {
                        "instrument": "DummyCam",
                        "visit": visit,
                        "detector": detector,
                        "region": region,
                    },
                )
                self.regions[(visit, detector)] = region

        # Visit 9 has non-polygon region (a Box), used to trigger the
        # "unexpected region type" warning in test_region_type_warning.
        for detector in (1, 2, 3, 4):
            lat = detector * 2 - 5
            region = Box.fromDegrees(17.0, lat - 1.0, 19.0, lat + 1.0)
            registry.insertDimensionData(
                "visit_detector_region",
                {
                    "instrument": "DummyCam",
                    "visit": 9,
                    "detector": detector,
                    "region": region,
                },
            )

        # Add a few dataset types; "no_obscore" is intentionally absent from
        # the obscore configuration.
        storage_class_factory = StorageClassFactory()
        storage_class = storage_class_factory.getStorageClass("StructuredDataDict")

        self.dataset_types: dict[str, DatasetType] = {}

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector", "exposure"])
        self.dataset_types["raw"] = DatasetType("raw", dimensions, storage_class)

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector", "visit"])
        self.dataset_types["calexp"] = DatasetType("calexp", dimensions, storage_class)

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector", "visit"])
        self.dataset_types["no_obscore"] = DatasetType("no_obscore", dimensions, storage_class)

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector"])
        self.dataset_types["calib"] = DatasetType("calib", dimensions, storage_class, isCalibration=True)

        for dataset_type in self.dataset_types.values():
            registry.registerDatasetType(dataset_type)

        # Add a few run collections.
        for run in (1, 2, 3, 4, 5, 6):
            registry.registerRun(f"run{run}")

        # Add a few chained collections; run6 is not in any chained collection.
        registry.registerCollection("chain12", CollectionType.CHAINED)
        registry.setCollectionChain("chain12", ("run1", "run2"))
        registry.registerCollection("chain34", CollectionType.CHAINED)
        registry.setCollectionChain("chain34", ("run3", "run4"))
        registry.registerCollection("chain-all", CollectionType.CHAINED)
        registry.setCollectionChain("chain-all", ("chain12", "chain34", "run5"))

        # And a tagged collection.
        registry.registerCollection("tagged", CollectionType.TAGGED)

    def make_obscore_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> Config:
        """Make configuration for obscore manager.

        Loads the baseline config from ``config/basic/obscore.yaml`` and
        overrides ``collections``/``collection_type`` when given.
        """
        obscore_config = Config(os.path.join(TESTDIR, "config", "basic", "obscore.yaml"))
        if collections is not None:
            obscore_config["collections"] = collections
        if collection_type is not None:
            obscore_config["collection_type"] = collection_type
        return obscore_config

    def _insert_dataset(
        self, registry: Registry, run: str, dataset_type: str, do_import: bool = False, **kwargs
    ) -> DatasetRef:
        """Insert or import one dataset into a specified run collection.

        ``do_import=True`` exercises ``Registry._importDatasets`` instead of
        ``insertDatasets``; both code paths should update obscore.
        """
        data_id = {"instrument": "DummyCam", "physical_filter": "d-r"}
        data_id.update(kwargs)
        coordinate = DataCoordinate.standardize(data_id, universe=registry.dimensions)
        if do_import:
            ds_type = self.dataset_types[dataset_type]
            ref = DatasetRef(ds_type, coordinate, run=run)
            [ref] = registry._importDatasets([ref])
        else:
            [ref] = registry.insertDatasets(dataset_type, [data_id], run=run)
        return ref

    def _insert_datasets(self, registry: Registry, do_import: bool = False) -> list[DatasetRef]:
        """Insert a small bunch of datasets into every run collection.

        Six of the seven datasets are of obscore-configured types.
        """
        return [
            self._insert_dataset(registry, "run1", "raw", detector=1, exposure=1, do_import=do_import),
            self._insert_dataset(registry, "run2", "calexp", detector=2, visit=2, do_import=do_import),
            self._insert_dataset(registry, "run3", "raw", detector=3, exposure=3, do_import=do_import),
            self._insert_dataset(registry, "run4", "calexp", detector=4, visit=4, do_import=do_import),
            self._insert_dataset(registry, "run5", "calexp", detector=4, visit=4, do_import=do_import),
            # This dataset type is not configured, will not be in obscore.
            self._insert_dataset(registry, "run5", "no_obscore", detector=1, visit=1, do_import=do_import),
            self._insert_dataset(registry, "run6", "raw", detector=1, exposure=4, do_import=do_import),
        ]

    def test_config_errors(self):
        """Test for handling various configuration problems."""
        # This raises pydantic ValidationError, which wraps ValueError.
        # TAGGED collection type requires exactly one collection name.
        exception_re = "'collections' must have one element"
        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry(None, "TAGGED")

        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry([], "TAGGED")

        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry(["run1", "run2"], "TAGGED")

        # Invalid regex.
        with self.assertRaisesRegex(ValueError, "Failed to compile regex"):
            self.make_registry(["+run"], "RUN")

    def test_schema(self):
        """Check how obscore schema is constructed"""
        # Minimal config: only the static columns appear.
        config = ObsCoreConfig(obs_collection="", dataset_types={}, facility_name="FACILITY")
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(list(table_spec.fields.names), [col.name for col in _STATIC_COLUMNS])

        # extra columns from top-level config
        config = ObsCoreConfig(
            obs_collection="",
            extra_columns={"c1": 1, "c2": "string", "c3": {"template": "{calib_level}", "type": "float"}},
            dataset_types={},
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(
            list(table_spec.fields.names),
            [col.name for col in _STATIC_COLUMNS] + ["c1", "c2", "c3"],
        )
        # Column type is inferred from the extra-column value.
        self.assertEqual(table_spec.fields["c1"].dtype, sqlalchemy.BigInteger)
        self.assertEqual(table_spec.fields["c2"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["c3"].dtype, sqlalchemy.Float)

        # extra columns from per-dataset type configs
        config = ObsCoreConfig(
            obs_collection="",
            extra_columns={"c1": 1},
            dataset_types={
                "raw": DatasetTypeConfig(
                    name="raw",
                    dataproduct_type="image",
                    calib_level=1,
                    extra_columns={"c2": "string"},
                ),
                "calexp": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=2,
                    extra_columns={"c3": 1e10},
                ),
            },
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(
            list(table_spec.fields.names),
            [col.name for col in _STATIC_COLUMNS] + ["c1", "c2", "c3"],
        )
        self.assertEqual(table_spec.fields["c1"].dtype, sqlalchemy.BigInteger)
        self.assertEqual(table_spec.fields["c2"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["c3"].dtype, sqlalchemy.Float)

        # Columns with the same names as in static list in configs, types
        # are not overriden.
        config = ObsCoreConfig(
            version=0,
            obs_collection="",
            extra_columns={"t_xel": 1e10},
            dataset_types={
                "raw": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=1,
                    extra_columns={"target_name": 1},
                ),
                "calexp": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=2,
                    extra_columns={"em_xel": "string"},
                ),
            },
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(list(table_spec.fields.names), [col.name for col in _STATIC_COLUMNS])
        # Static column types win over the extra-column value types.
        self.assertEqual(table_spec.fields["t_xel"].dtype, sqlalchemy.Integer)
        self.assertEqual(table_spec.fields["target_name"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["em_xel"].dtype, sqlalchemy.Integer)

    def test_insert_existing_collection(self):
        """Test insert and import registry methods, with various restrictions
        on collection names.
        """
        # First item is collections, second item is expected record count.
        test_data = (
            (None, 6),
            (["run1", "run2"], 2),
            (["run[34]"], 2),  # collection names may be regex patterns
            (["[rR]un[^6]"], 5),
        )

        for collections, count in test_data:
            for do_import in (False, True):
                registry = self.make_registry(collections)
                obscore = registry.obsCoreTableManager
                assert obscore is not None
                self._insert_datasets(registry, do_import)

                with obscore.query() as result:
                    rows = list(result)
                    self.assertEqual(len(rows), count)

                # Also check `query` method with COUNT(*)
                with obscore.query([sqlalchemy.sql.func.count()]) as result:
                    scalar = result.scalar_one()
                    self.assertEqual(scalar, count)

    def test_drop_datasets(self):
        """Test for dropping datasets after obscore insert."""
        collections = None
        registry = self.make_registry(collections)
        obscore = registry.obsCoreTableManager
        assert obscore is not None
        refs = self._insert_datasets(registry)

        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 6)

        # drop single dataset
        registry.removeDatasets(ref for ref in refs if ref.run == "run1")
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 5)

        # drop whole run collection
        registry.removeCollection("run6")
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 4)

    def test_associate(self):
        """Test for associating datasets to TAGGED collection."""
        collections = ["tagged"]
        registry = self.make_registry(collections, "TAGGED")
        obscore = registry.obsCoreTableManager
        assert obscore is not None
        refs = self._insert_datasets(registry)

        # Inserting into run collections adds nothing to obscore when it
        # tracks a TAGGED collection.
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 0)

        # Associating a dataset adds one obscore record.
        registry.associate("tagged", (ref for ref in refs if ref.run == "run1"))
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 1)

        # Associate datasets that are not in obscore
        registry.associate("tagged", (ref for ref in refs if ref.run == "run3"))
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 2)

        # Disassociate them
        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run3"))
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 1)

        # Non-associated dataset, should be OK and not throw.
        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run2"))
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 1)

        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run1"))
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 0)

    def test_region_type_warning(self) -> None:
        """Test that non-polygon region generates one or more warnings."""
        collections = None
        registry = self.make_registry(collections)

        # Visit 9 regions were registered as Box, which obscore cannot
        # represent; inserting a dataset for it must warn.
        with warnings.catch_warnings(record=True) as warning_records:
            self._insert_dataset(registry, "run2", "calexp", detector=2, visit=9)
        self.assertEqual(len(warning_records), 1)
        for record in warning_records:
            self.assertRegex(
                str(record.message),
                "Unexpected region type: .*lsst.sphgeom._sphgeom.Box.*",
            )

    def test_update_exposure_region(self) -> None:
        """Test for update_exposure_regions method."""
        registry = self.make_registry(["run1"])
        obscore = registry.obsCoreTableManager
        assert obscore is not None

        # Exposure 4 is not associated with any visit.
        for detector in (1, 2, 3, 4):
            self._insert_dataset(registry, "run1", "raw", detector=detector, exposure=4)

        # All spatial columns should be None.
        with obscore.query() as result:
            rows = list(result)
            self.assertEqual(len(rows), 4)
            for row in rows:
                self.assertIsNone(row.s_ra)
                self.assertIsNone(row.s_dec)
                self.assertIsNone(row.s_region)

        # Assign Region from visit 4 to detectors 1 and 2 only.
        count = obscore.update_exposure_regions(
            "DummyCam", [(4, 1, self.regions[(4, 1)]), (4, 2, self.regions[(4, 2)])]
        )
        self.assertEqual(count, 2)

        with obscore.query(["s_ra", "s_dec", "s_region", "lsst_detector"]) as result:
            rows = list(result)
            self.assertEqual(len(rows), 4)
            for row in rows:
                if row.lsst_detector in (1, 2):
                    self.assertIsNotNone(row.s_ra)
                    self.assertIsNotNone(row.s_dec)
                    self.assertIsNotNone(row.s_region)
                else:
                    self.assertIsNone(row.s_ra)
                    self.assertIsNone(row.s_dec)
                    self.assertIsNone(row.s_region)
class SQLiteObsCoreTest(ObsCoreTests, unittest.TestCase):
    """Unit test for obscore with SQLite backend."""

    def setUp(self):
        # Fresh temporary directory per test; doubles as butler root.
        self.root = makeTestTempDir(TESTDIR)

    def tearDown(self):
        removeTestTempDir(self.root)

    def make_registry_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> RegistryConfig:
        # docstring inherited from a base class
        # Each registry gets its own database file inside self.root.
        _, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
        config = RegistryConfig()
        # Fix: the URL must reference the temporary file just created;
        # previously `filename` was unused and the URL was malformed.
        config["db"] = f"sqlite:///{filename}"
        config["managers", "obscore"] = {
            "cls": "lsst.daf.butler.registry.obscore.ObsCoreLiveTableManager",
            "config": self.make_obscore_config(collections, collection_type),
        }
        return config
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresObsCoreTest(ObsCoreTests, unittest.TestCase):
    """Unit test for obscore with PostgreSQL backend."""

    @classmethod
    def _handler(cls, postgresql):
        # Prepare the freshly-initialized database: obscore indexing
        # needs the btree_gist extension.
        engine = sqlalchemy.engine.create_engine(postgresql.url())
        with engine.begin() as connection:
            connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))

    @classmethod
    def setUpClass(cls):
        # Create the postgres test server factory; the initialized database
        # is cached so per-test servers start quickly.
        cls.postgresql = testing.postgresql.PostgresqlFactory(
            cache_initialized_db=True, on_initialized=cls._handler
        )
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        cls.postgresql.clear_cache()
        super().tearDownClass()

    def setUp(self):
        self.root = makeTestTempDir(TESTDIR)
        self.server = self.postgresql()
        self.count = 0

    def tearDown(self):
        removeTestTempDir(self.root)
        # Fix: stop the per-test server; previously this line started a
        # *new* server instead, leaking a postgres process per test.
        self.server.stop()

    def make_registry_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> RegistryConfig:
        # docstring inherited from a base class
        self.count += 1
        config = RegistryConfig()
        config["db"] = self.server.url()
        # Use unique namespace for each instance, some tests may use sub-tests.
        config["namespace"] = f"namespace{self.count}"
        config["managers", "obscore"] = {
            "cls": "lsst.daf.butler.registry.obscore.ObsCoreLiveTableManager",
            "config": self.make_obscore_config(collections, collection_type),
        }
        return config
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresPgSphereObsCoreTest(PostgresObsCoreTest):
    """Unit test for obscore with PostgreSQL backend and pgsphere plugin."""

    @classmethod
    def _handler(cls, postgresql):
        # Base class installs btree_gist; we additionally need pg_sphere.
        super()._handler(postgresql)
        engine = sqlalchemy.engine.create_engine(postgresql.url())
        with engine.begin() as connection:
            try:
                connection.execute(sqlalchemy.text("CREATE EXTENSION pg_sphere"))
            except sqlalchemy.exc.DatabaseError as exc:
                # Server build without pg_sphere: skip instead of failing.
                raise unittest.SkipTest(f"pg_sphere extension does not exist: {exc}") from None

    def make_obscore_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> Config:
        """Make configuration for obscore manager."""
        config = super().make_obscore_config(collections, collection_type)
        # Enable the pgsphere spatial plugin on top of the base config.
        plugin = {
            "cls": "lsst.daf.butler.registry.obscore.pgsphere.PgSphereObsCorePlugin",
            "config": {
                "region_column": "pgs_region",
                "position_column": "pgs_center",
            },
        }
        config["spatial_plugins"] = {"pgsphere": plugin}
        return config

    def test_spatial(self):
        """Test that pgsphere plugin fills spatial columns."""
        registry = self.make_registry(None)
        obscore = registry.obsCoreTableManager
        assert obscore is not None
        self._insert_datasets(registry)

        # select everything
        with obscore.query() as result:
            self.assertEqual(len(list(result)), 6)

        db = cast(SqlRegistry, registry)._db
        assert registry.obsCoreTableManager is not None
        table = cast(ObsCoreLiveTableManager, registry.obsCoreTableManager).table

        # It's not easy to generate spatial queries in sqlalchemy, use plain
        # text queries for testing. Each entry pairs a query with the
        # expected row count.
        spatial_cases = (
            # position matching visit=1, there is a single dataset
            (f"SELECT * FROM {table.key} WHERE pgs_center <-> '(2d,0d)'::spoint < .1", 1),
            # position matching visit=4, there are two datasets
            (f"SELECT * FROM {table.key} WHERE pgs_center <-> '(272d,0d)'::spoint < .1", 2),
            # position matching visit=1, there is a single dataset
            (f"SELECT * FROM {table.key} WHERE '(2d,-3d)'::spoint @ pgs_region", 1),
            # position matching visit=4, there are two datasets
            (f"SELECT * FROM {table.key} WHERE '(272d,3d)'::spoint @ pgs_region", 2),
        )
        for query, expected_count in spatial_cases:
            with db.query(sqlalchemy.text(query)) as results:
                self.assertEqual(len(list(results)), expected_count)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()