Coverage for tests/test_obscore.py: 18%
264 statements
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import gc
import os
import tempfile
import unittest
import warnings
from abc import abstractmethod
from typing import Dict, List, Optional

import astropy.time
import sqlalchemy
from lsst.daf.butler import (
    CollectionType,
    Config,
    DatasetIdGenEnum,
    DatasetRef,
    DatasetType,
    StorageClassFactory,
)
from lsst.daf.butler.registry import Registry, RegistryConfig
from lsst.daf.butler.registry.obscore import DatasetTypeConfig, ObsCoreConfig, ObsCoreSchema
from lsst.daf.butler.registry.obscore._schema import _STATIC_COLUMNS
from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
from lsst.sphgeom import Box, ConvexPolygon, LonLat, UnitVector3d

try:
    import testing.postgresql
except ImportError:
    testing = None

TESTDIR = os.path.abspath(os.path.dirname(__file__))


class ObsCoreTests:
    """Base class for testing obscore manager functionality."""

    def make_registry(
        self, collections: Optional[List[str]] = None, collection_type: Optional[str] = None
    ) -> Registry:
        """Create a new empty Registry."""
        config = self.make_registry_config(collections, collection_type)
        registry = Registry.createFromConfig(config, butlerRoot=self.root)
        self.initialize_registry(registry)
        return registry

    @abstractmethod
    def make_registry_config(
        self, collections: Optional[List[str]] = None, collection_type: Optional[str] = None
    ) -> RegistryConfig:
        """Make Registry configuration."""
        raise NotImplementedError()

    def initialize_registry(self, registry: Registry) -> None:
        """Populate Registry with the things that we need for tests."""

        registry.insertDimensionData("instrument", {"name": "DummyCam"})
        registry.insertDimensionData(
            "physical_filter", {"instrument": "DummyCam", "name": "d-r", "band": "r"}
        )
        for detector in (1, 2, 3, 4):
            registry.insertDimensionData(
                "detector", {"instrument": "DummyCam", "id": detector, "full_name": f"detector{detector}"}
            )

        for exposure in (1, 2, 3, 4):
            registry.insertDimensionData(
                "exposure",
                {
                    "instrument": "DummyCam",
                    "id": exposure,
                    "obs_id": f"exposure{exposure}",
                    "physical_filter": "d-r",
                },
            )

        registry.insertDimensionData("visit_system", {"instrument": "DummyCam", "id": 1, "name": "default"})

        for visit in (1, 2, 3, 4, 9):
            visit_start = astropy.time.Time(f"2020-01-01 08:0{visit}:00", scale="tai")
            visit_end = astropy.time.Time(f"2020-01-01 08:0{visit}:45", scale="tai")
            registry.insertDimensionData(
                "visit",
                {
                    "instrument": "DummyCam",
                    "id": visit,
                    "name": f"visit{visit}",
                    "physical_filter": "d-r",
                    "visit_system": 1,
                    "datetime_begin": visit_start,
                    "datetime_end": visit_end,
                },
            )

        # Only a couple of exposures are linked to visits.
        for visit in (1, 2):
            registry.insertDimensionData(
                "visit_definition",
                {
                    "instrument": "DummyCam",
                    "exposure": visit,
                    "visit": visit,
                },
            )
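
        # Each (visit, detector) pair below gets a 2x2-degree polygon centered
        # at lon = visit * 90 - 88, lat = detector * 2 - 5, i.e. visits 1-4
        # are centered at lon 2, 92, 182, and 272 degrees, and detectors 1-4
        # at lat -3, -1, 1, and 3 degrees; the pgsphere spatial tests below
        # query these positions.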
        for visit in (1, 2, 3, 4):
            for detector in (1, 2, 3, 4):
                lon = visit * 90 - 88
                lat = detector * 2 - 5
                region = ConvexPolygon(
                    [
                        UnitVector3d(LonLat.fromDegrees(lon - 1.0, lat - 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon + 1.0, lat - 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon + 1.0, lat + 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon - 1.0, lat + 1.0)),
                    ]
                )
                registry.insertDimensionData(
                    "visit_detector_region",
                    {
                        "instrument": "DummyCam",
                        "visit": visit,
                        "detector": detector,
                        "region": region,
                    },
                )

        # Visit 9 has a non-polygon region.
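        # These Box regions exercise the "unexpected region type" warning
        # path checked in test_region_type_warning.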
        for detector in (1, 2, 3, 4):
            lat = detector * 2 - 5
            region = Box.fromDegrees(17.0, lat - 1.0, 19.0, lat + 1.0)
            registry.insertDimensionData(
                "visit_detector_region",
                {
                    "instrument": "DummyCam",
                    "visit": 9,
                    "detector": detector,
                    "region": region,
                },
            )

        # Add a few dataset types.
        storage_class_factory = StorageClassFactory()
        storage_class = storage_class_factory.getStorageClass("StructuredDataDict")

        self.dataset_types: Dict[str, DatasetType] = {}

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector", "exposure"])
        self.dataset_types["raw"] = DatasetType("raw", dimensions, storage_class)

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector", "visit"])
        self.dataset_types["calexp"] = DatasetType("calexp", dimensions, storage_class)

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector", "visit"])
        self.dataset_types["no_obscore"] = DatasetType("no_obscore", dimensions, storage_class)

        dimensions = registry.dimensions.extract(["instrument", "physical_filter", "detector"])
        self.dataset_types["calib"] = DatasetType("calib", dimensions, storage_class, isCalibration=True)

        for dataset_type in self.dataset_types.values():
            registry.registerDatasetType(dataset_type)

        # Add a few run collections.
        for run in (1, 2, 3, 4, 5, 6):
            registry.registerRun(f"run{run}")

        # Add a few chained collections; run6 is not in any chained collection.
        registry.registerCollection("chain12", CollectionType.CHAINED)
        registry.setCollectionChain("chain12", ("run1", "run2"))
        registry.registerCollection("chain34", CollectionType.CHAINED)
        registry.setCollectionChain("chain34", ("run3", "run4"))
        registry.registerCollection("chain-all", CollectionType.CHAINED)
        registry.setCollectionChain("chain-all", ("chain12", "chain34", "run5"))
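        # The resulting structure is:
        #   chain-all -> chain12 -> (run1, run2)
        #             -> chain34 -> (run3, run4)
        #             -> run5
        # with run6 deliberately left out of every chain.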

        # And a tagged collection.
        registry.registerCollection("tagged", CollectionType.TAGGED)

    def make_obscore_config(
        self, collections: Optional[List[str]] = None, collection_type: Optional[str] = None
    ) -> Config:
        """Make configuration for obscore manager."""
        obscore_config = Config(os.path.join(TESTDIR, "config", "basic", "obscore.yaml"))
        if collections is not None:
            obscore_config["collections"] = collections
        if collection_type is not None:
            obscore_config["collection_type"] = collection_type
        return obscore_config

    def _insert_dataset(
        self, registry: Registry, run: str, dataset_type: str, do_import: bool = False, **kwargs
    ) -> DatasetRef:
        """Insert or import one dataset into a specified run collection."""
        data_id = {"instrument": "DummyCam", "physical_filter": "d-r"}
        data_id.update(kwargs)
        if do_import:
            ds_type = self.dataset_types[dataset_type]
            dataset_id = registry.datasetIdFactory.makeDatasetId(
                run, ds_type, data_id, DatasetIdGenEnum.UNIQUE
            )
            ref = DatasetRef(ds_type, data_id, id=dataset_id, run=run)
            [ref] = registry._importDatasets([ref])
        else:
            [ref] = registry.insertDatasets(dataset_type, [data_id], run=run)
        return ref

    def _insert_datasets(self, registry: Registry, do_import: bool = False) -> List[DatasetRef]:
        """Insert a small bunch of datasets into every run collection."""
        return [
            self._insert_dataset(registry, "run1", "raw", detector=1, exposure=1, do_import=do_import),
            self._insert_dataset(registry, "run2", "calexp", detector=2, visit=2, do_import=do_import),
            self._insert_dataset(registry, "run3", "raw", detector=3, exposure=3, do_import=do_import),
            self._insert_dataset(registry, "run4", "calexp", detector=4, visit=4, do_import=do_import),
            self._insert_dataset(registry, "run5", "calexp", detector=4, visit=4, do_import=do_import),
            # This dataset type is not configured, so it will not be in obscore.
            self._insert_dataset(registry, "run5", "no_obscore", detector=1, visit=1, do_import=do_import),
            self._insert_dataset(registry, "run6", "raw", detector=1, exposure=4, do_import=do_import),
        ]

    def _obscore_select(self, registry: Registry) -> list:
        """Select all rows from obscore table."""
        db = registry._db
        table = registry._managers.obscore.table
        with db.query(table.select()) as results:
            return results.fetchall()

    def test_config_errors(self):
        """Test for handling various configuration problems."""

        # This raises pydantic ValidationError, which wraps ValueError.
        exception_re = "'collections' must have one element"
        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry(None, "TAGGED")

        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry([], "TAGGED")

        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry(["run1", "run2"], "TAGGED")

        # Invalid regex.
        with self.assertRaisesRegex(ValueError, "Failed to compile regex"):
            self.make_registry(["+run"], "RUN")

    def test_schema(self):
        """Check how the obscore schema is constructed."""

        config = ObsCoreConfig(obs_collection="", dataset_types=[], facility_name="FACILITY")
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(list(table_spec.fields.names), [col.name for col in _STATIC_COLUMNS])

        # Extra columns from the top-level config; the column type is derived
        # from the value, or given explicitly in the dict form.
        config = ObsCoreConfig(
            obs_collection="",
            extra_columns={"c1": 1, "c2": "string", "c3": {"template": "{calib_level}", "type": "float"}},
            dataset_types=[],
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(
            list(table_spec.fields.names),
            [col.name for col in _STATIC_COLUMNS] + ["c1", "c2", "c3"],
        )
        self.assertEqual(table_spec.fields["c1"].dtype, sqlalchemy.BigInteger)
        self.assertEqual(table_spec.fields["c2"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["c3"].dtype, sqlalchemy.Float)

        # Extra columns from per-dataset-type configs.
        config = ObsCoreConfig(
            obs_collection="",
            extra_columns={"c1": 1},
            dataset_types={
                "raw": DatasetTypeConfig(
                    name="raw",
                    dataproduct_type="image",
                    calib_level=1,
                    extra_columns={"c2": "string"},
                ),
                "calexp": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=2,
                    extra_columns={"c3": 1e10},
                ),
            },
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(
            list(table_spec.fields.names),
            [col.name for col in _STATIC_COLUMNS] + ["c1", "c2", "c3"],
        )
        self.assertEqual(table_spec.fields["c1"].dtype, sqlalchemy.BigInteger)
        self.assertEqual(table_spec.fields["c2"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["c3"].dtype, sqlalchemy.Float)

        # Columns in configs that share names with the static columns do not
        # override the static column types.
        config = ObsCoreConfig(
            version=0,
            obs_collection="",
            extra_columns={"t_xel": 1e10},
            dataset_types={
                "raw": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=1,
                    extra_columns={"target_name": 1},
                ),
                "calexp": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=2,
                    extra_columns={"em_xel": "string"},
                ),
            },
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(list(table_spec.fields.names), [col.name for col in _STATIC_COLUMNS])
        self.assertEqual(table_spec.fields["t_xel"].dtype, sqlalchemy.Integer)
        self.assertEqual(table_spec.fields["target_name"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["em_xel"].dtype, sqlalchemy.Integer)

    def test_insert_existing_collection(self):
        """Test insert and import registry methods, with various restrictions
        on collection names.
        """

        # First item is collections, second item is expected record count.
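        # _insert_datasets() creates seven datasets, one of which has the
        # unconfigured "no_obscore" dataset type, so at most six rows can
        # appear in obscore; run6 matches neither "run[34]" nor "[rR]un[^6]".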
        test_data = (
            (None, 6),
            (["run1", "run2"], 2),
            (["run[34]"], 2),
            (["[rR]un[^6]"], 5),
        )

        for collections, count in test_data:
            for do_import in (False, True):
                registry = self.make_registry(collections)
                self._insert_datasets(registry, do_import)

                rows = self._obscore_select(registry)
                self.assertEqual(len(rows), count)

    def test_drop_datasets(self):
        """Test for dropping datasets after obscore insert."""

        collections = None
        registry = self.make_registry(collections)
        refs = self._insert_datasets(registry)

        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 6)

        # drop single dataset
        registry.removeDatasets(ref for ref in refs if ref.run == "run1")
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 5)

        # drop whole run collection
        registry.removeCollection("run6")
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 4)

    def test_associate(self):
        """Test for associating datasets to a TAGGED collection."""

        collections = ["tagged"]
        registry = self.make_registry(collections, "TAGGED")
        refs = self._insert_datasets(registry)

        # Nothing has been tagged yet, so obscore is empty.
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 0)

        # Associate datasets from run1; this should add them to obscore.
        registry.associate("tagged", (ref for ref in refs if ref.run == "run1"))
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 1)

        # Associate more datasets that are not yet in obscore.
        registry.associate("tagged", (ref for ref in refs if ref.run == "run3"))
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 2)

        # Disassociate them.
        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run3"))
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 1)

        # Disassociating a non-associated dataset should be OK and not throw.
        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run2"))
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 1)

        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run1"))
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 0)

    def test_region_type_warning(self, count: int = 1) -> None:
        """Test that a non-polygon region generates one or more warnings."""

        collections = None
        registry = self.make_registry(collections)

        with warnings.catch_warnings(record=True) as warning_records:
            self._insert_dataset(registry, "run2", "calexp", detector=2, visit=9)
        self.assertEqual(len(warning_records), count)
        for record in warning_records:
            self.assertRegex(
                str(record.message),
                "Unexpected region type for obscore dataset.*lsst.sphgeom._sphgeom.Box.*",
            )


class SQLiteObsCoreTest(ObsCoreTests, unittest.TestCase):
    """Unit test for obscore with SQLite backend."""

    def setUp(self):
        self.root = makeTestTempDir(TESTDIR)

    def tearDown(self):
        removeTestTempDir(self.root)

    def make_registry_config(
        self, collections: Optional[List[str]] = None, collection_type: Optional[str] = None
    ) -> RegistryConfig:
        # docstring inherited from a base class
        _, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
        config = RegistryConfig()
        config["db"] = f"sqlite:///{filename}"
        config["managers", "obscore"] = {
            "cls": "lsst.daf.butler.registry.obscore.ObsCoreLiveTableManager",
            "config": self.make_obscore_config(collections, collection_type),
        }
        return config


@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresObsCoreTest(ObsCoreTests, unittest.TestCase):
    """Unit test for obscore with PostgreSQL backend."""

    @classmethod
    def _handler(cls, postgresql):
        engine = sqlalchemy.engine.create_engine(postgresql.url())
        with engine.begin() as connection:
            connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))

    @classmethod
    def setUpClass(cls):
        # Create the postgres test server.
        cls.postgresql = testing.postgresql.PostgresqlFactory(
            cache_initialized_db=True, on_initialized=cls._handler
        )
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        cls.postgresql.clear_cache()
        super().tearDownClass()

    def setUp(self):
        self.root = makeTestTempDir(TESTDIR)
        self.server = self.postgresql()
        self.count = 0

    def tearDown(self):
        removeTestTempDir(self.root)
        # Stop the per-test PostgreSQL server started in setUp.
        self.server.stop()

    def make_registry_config(
        self, collections: Optional[List[str]] = None, collection_type: Optional[str] = None
    ) -> RegistryConfig:
        # docstring inherited from a base class
        self.count += 1
        config = RegistryConfig()
        config["db"] = self.server.url()
        # Use a unique namespace for each instance; some tests may use
        # sub-tests.
        config["namespace"] = f"namespace{self.count}"
        config["managers", "obscore"] = {
            "cls": "lsst.daf.butler.registry.obscore.ObsCoreLiveTableManager",
            "config": self.make_obscore_config(collections, collection_type),
        }
        return config


@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresPgSphereObsCoreTest(PostgresObsCoreTest):
    """Unit test for obscore with PostgreSQL backend and pgsphere plugin."""

    @classmethod
    def _handler(cls, postgresql):
        super()._handler(postgresql)
        engine = sqlalchemy.engine.create_engine(postgresql.url())
        with engine.begin() as connection:
            try:
                connection.execute(sqlalchemy.text("CREATE EXTENSION pg_sphere"))
            except sqlalchemy.exc.DatabaseError as exc:
                raise unittest.SkipTest(f"pg_sphere extension does not exist: {exc}")

    def make_obscore_config(
        self, collections: Optional[List[str]] = None, collection_type: Optional[str] = None
    ) -> Config:
        """Make configuration for obscore manager."""
        obscore_config = super().make_obscore_config(collections, collection_type)
        obscore_config["spatial_plugins"] = {
            "pgsphere": {
                "cls": "lsst.daf.butler.registry.obscore.pgsphere.PgSphereObsCorePlugin",
                "config": {
                    "region_column": "pgs_region",
                    "position_column": "pgs_center",
                },
            }
        }
        return obscore_config

    def test_spatial(self):
        """Test that pgsphere plugin fills spatial columns."""

        collections = None
        registry = self.make_registry(collections)
        self._insert_datasets(registry)

        # select everything
        rows = self._obscore_select(registry)
        self.assertEqual(len(rows), 6)

        db = registry._db
        table = registry._managers.obscore.table

        # It's not easy to generate spatial queries in SQLAlchemy, so use
        # plain-text queries for testing.
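        # pgs_region and pgs_center are the plugin columns configured in
        # make_obscore_config above; the positions queried below line up with
        # the region grid built in initialize_registry.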

        # position matching visit=1, there is a single dataset
        query = f"SELECT * FROM {table.key} WHERE pgs_center <-> '(2d,0d)'::spoint < .1"
        with db.query(sqlalchemy.text(query)) as results:
            self.assertEqual(len(list(results)), 1)

        # position matching visit=4, there are two datasets
        query = f"SELECT * FROM {table.key} WHERE pgs_center <-> '(272d,0d)'::spoint < .1"
        with db.query(sqlalchemy.text(query)) as results:
            self.assertEqual(len(list(results)), 2)

        # position matching visit=1, there is a single dataset
        query = f"SELECT * FROM {table.key} WHERE '(2d,-3d)'::spoint @ pgs_region"
        with db.query(sqlalchemy.text(query)) as results:
            self.assertEqual(len(list(results)), 1)

        # position matching visit=4, there are two datasets
        query = f"SELECT * FROM {table.key} WHERE '(272d,3d)'::spoint @ pgs_region"
        with db.query(sqlalchemy.text(query)) as results:
            self.assertEqual(len(list(results)), 2)

    def test_region_type_warning(self) -> None:
        """Test that a non-polygon region generates a warning."""
        # The pgsphere plugin adds one more warning.
        super().test_region_type_warning(2)


if __name__ == "__main__":
    unittest.main()