Coverage for tests/test_obscore.py: 18%

315 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-15 02:03 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import gc 

29import os 

30import tempfile 

31import unittest 

32import warnings 

33from abc import abstractmethod 

34from typing import cast 

35 

36import astropy.time 

37import sqlalchemy 

38from lsst.daf.butler import ( 

39 CollectionType, 

40 Config, 

41 DataCoordinate, 

42 DatasetRef, 

43 DatasetType, 

44 StorageClassFactory, 

45) 

46from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory 

47from lsst.daf.butler.registry.obscore import ( 

48 DatasetTypeConfig, 

49 ObsCoreConfig, 

50 ObsCoreLiveTableManager, 

51 ObsCoreSchema, 

52) 

53from lsst.daf.butler.registry.obscore._schema import _STATIC_COLUMNS 

54from lsst.daf.butler.registry.sql_registry import SqlRegistry 

55from lsst.daf.butler.tests.utils import TestCaseMixin, makeTestTempDir, removeTestTempDir 

56from lsst.sphgeom import Box, ConvexPolygon, LonLat, UnitVector3d 

57 

try:
    # Optional dependency: when it is missing, ``testing`` is set to None and
    # the PostgreSQL-backed test classes below are skipped.
    import testing.postgresql  # type: ignore
except ImportError:
    testing = None

# Absolute path of the directory containing this module; used to locate the
# obscore test configuration and as the parent of per-test temp directories.
TESTDIR = os.path.abspath(os.path.dirname(__file__))

64 

65 

class ObsCoreTests(TestCaseMixin):
    """Base class for testing obscore manager functionality.

    Concrete subclasses mix this into `unittest.TestCase`, set ``root`` (for
    example in ``setUp``) and implement `make_registry_config` for a specific
    database backend.
    """

    root: str

    def make_registry(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> SqlRegistry:
        """Create new empty Registry.

        Parameters
        ----------
        collections : `list` [`str`], optional
            Collection names/patterns forwarded to the obscore configuration.
        collection_type : `str`, optional
            Collection type forwarded to the obscore configuration.

        Returns
        -------
        registry : `SqlRegistry`
            New registry, already populated by `initialize_registry`.
        """
        config = self.make_registry_config(collections, collection_type)
        registry = _RegistryFactory(config).create_from_config(butlerRoot=self.root)
        self.initialize_registry(registry)
        return registry

    @abstractmethod
    def make_registry_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> RegistryConfig:
        """Make Registry configuration.

        Must be implemented by backend-specific subclasses.
        """
        raise NotImplementedError()

    def initialize_registry(self, registry: SqlRegistry) -> None:
        """Populate Registry with the things that we need for tests.

        Inserts dimension records for a single ``DummyCam`` instrument, stores
        per-(visit, detector) polygon regions in ``self.regions``, registers
        the dataset types kept in ``self.dataset_types`` and creates RUN,
        CHAINED and TAGGED collections.
        """
        registry.insertDimensionData("instrument", {"name": "DummyCam"})
        registry.insertDimensionData(
            "physical_filter", {"instrument": "DummyCam", "name": "d-r", "band": "r"}
        )
        registry.insertDimensionData("day_obs", {"instrument": "DummyCam", "id": 20200101})
        for detector in (1, 2, 3, 4):
            registry.insertDimensionData(
                "detector", {"instrument": "DummyCam", "id": detector, "full_name": f"detector{detector}"}
            )

        for exposure in (1, 2, 3, 4):
            registry.insertDimensionData("group", {"instrument": "DummyCam", "name": f"group{exposure}"})
            registry.insertDimensionData(
                "exposure",
                {
                    "instrument": "DummyCam",
                    "id": exposure,
                    "obs_id": f"exposure{exposure}",
                    "physical_filter": "d-r",
                    "group": f"group{exposure}",
                    "day_obs": 20200101,
                },
            )

        registry.insertDimensionData("visit_system", {"instrument": "DummyCam", "id": 1, "name": "default"})

        for visit in (1, 2, 3, 4, 9):
            visit_start = astropy.time.Time(f"2020-01-01 08:0{visit}:00", scale="tai")
            visit_end = astropy.time.Time(f"2020-01-01 08:0{visit}:45", scale="tai")
            registry.insertDimensionData(
                "visit",
                {
                    "instrument": "DummyCam",
                    "id": visit,
                    "name": f"visit{visit}",
                    "physical_filter": "d-r",
                    "datetime_begin": visit_start,
                    "datetime_end": visit_end,
                    "day_obs": 20200101,
                },
            )
            registry.insertDimensionData(
                "visit_system_membership",
                {"instrument": "DummyCam", "visit": visit, "visit_system": 1},
            )

        # Only a couple of exposures are linked to visits.
        for visit in (1, 2):
            registry.insertDimensionData(
                "visit_definition",
                {
                    "instrument": "DummyCam",
                    "exposure": visit,
                    "visit": visit,
                },
            )

        # Map (visit, detector) to a 2x2 degree polygon centered at
        # lon = visit * 90 - 88, lat = detector * 2 - 5.
        self.regions: dict[tuple[int, int], ConvexPolygon] = {}
        for visit in (1, 2, 3, 4):
            for detector in (1, 2, 3, 4):
                lon = visit * 90 - 88
                lat = detector * 2 - 5
                region = ConvexPolygon(
                    [
                        UnitVector3d(LonLat.fromDegrees(lon - 1.0, lat - 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon + 1.0, lat - 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon + 1.0, lat + 1.0)),
                        UnitVector3d(LonLat.fromDegrees(lon - 1.0, lat + 1.0)),
                    ]
                )
                registry.insertDimensionData(
                    "visit_detector_region",
                    {
                        "instrument": "DummyCam",
                        "visit": visit,
                        "detector": detector,
                        "region": region,
                    },
                )
                self.regions[(visit, detector)] = region

        # Visit 9 has a non-polygon (Box) region.
        for detector in (1, 2, 3, 4):
            lat = detector * 2 - 5
            region = Box.fromDegrees(17.0, lat - 1.0, 19.0, lat + 1.0)
            registry.insertDimensionData(
                "visit_detector_region",
                {
                    "instrument": "DummyCam",
                    "visit": 9,
                    "detector": detector,
                    "region": region,
                },
            )

        # Add a few dataset types.
        storage_class_factory = StorageClassFactory()
        storage_class = storage_class_factory.getStorageClass("StructuredDataDict")

        self.dataset_types: dict[str, DatasetType] = {}

        dimensions = registry.dimensions.conform(["instrument", "physical_filter", "detector", "exposure"])
        self.dataset_types["raw"] = DatasetType("raw", dimensions, storage_class)

        dimensions = registry.dimensions.conform(["instrument", "physical_filter", "detector", "visit"])
        self.dataset_types["calexp"] = DatasetType("calexp", dimensions, storage_class)

        dimensions = registry.dimensions.conform(["instrument", "physical_filter", "detector", "visit"])
        self.dataset_types["no_obscore"] = DatasetType("no_obscore", dimensions, storage_class)

        dimensions = registry.dimensions.conform(["instrument", "physical_filter", "detector"])
        self.dataset_types["calib"] = DatasetType("calib", dimensions, storage_class, isCalibration=True)

        for dataset_type in self.dataset_types.values():
            registry.registerDatasetType(dataset_type)

        # Add a few run collections.
        for run in (1, 2, 3, 4, 5, 6):
            registry.registerRun(f"run{run}")

        # Add a few chained collections; run6 is not in any chained collection.
        registry.registerCollection("chain12", CollectionType.CHAINED)
        registry.setCollectionChain("chain12", ("run1", "run2"))
        registry.registerCollection("chain34", CollectionType.CHAINED)
        registry.setCollectionChain("chain34", ("run3", "run4"))
        registry.registerCollection("chain-all", CollectionType.CHAINED)
        registry.setCollectionChain("chain-all", ("chain12", "chain34", "run5"))

        # And a tagged collection.
        registry.registerCollection("tagged", CollectionType.TAGGED)

    def make_obscore_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> Config:
        """Make configuration for obscore manager.

        Starts from ``config/basic/obscore.yaml`` next to this module and
        overrides ``collections`` / ``collection_type`` when they are given.
        """
        obscore_config = Config(os.path.join(TESTDIR, "config", "basic", "obscore.yaml"))
        if collections is not None:
            obscore_config["collections"] = collections
        if collection_type is not None:
            obscore_config["collection_type"] = collection_type
        return obscore_config

    def _insert_dataset(
        self, registry: SqlRegistry, run: str, dataset_type: str, do_import: bool = False, **kwargs
    ) -> DatasetRef:
        """Insert or import one dataset into a specified run collection.

        Parameters
        ----------
        registry : `SqlRegistry`
            Registry to add the dataset to.
        run : `str`
            Name of the RUN collection.
        dataset_type : `str`
            Name of a dataset type from ``self.dataset_types``.
        do_import : `bool`, optional
            If `True`, use ``_importDatasets`` instead of ``insertDatasets``.
        **kwargs
            Extra data ID values added to the fixed instrument and
            physical_filter values.

        Returns
        -------
        ref : `DatasetRef`
            Reference to the new dataset.
        """
        data_id = {"instrument": "DummyCam", "physical_filter": "d-r"}
        data_id.update(kwargs)
        if do_import:
            coordinate = DataCoordinate.standardize(data_id, universe=registry.dimensions)
            ds_type = self.dataset_types[dataset_type]
            ref = DatasetRef(ds_type, coordinate, run=run)
            [ref] = registry._importDatasets([ref])
        else:
            [ref] = registry.insertDatasets(dataset_type, [data_id], run=run)
        return ref

    def _insert_datasets(self, registry: SqlRegistry, do_import: bool = False) -> list[DatasetRef]:
        """Insert a small bunch of datasets into every run collection.

        Six of the seven datasets use dataset types configured for obscore;
        the ``no_obscore`` one does not.
        """
        return [
            self._insert_dataset(registry, "run1", "raw", detector=1, exposure=1, do_import=do_import),
            self._insert_dataset(registry, "run2", "calexp", detector=2, visit=2, do_import=do_import),
            self._insert_dataset(registry, "run3", "raw", detector=3, exposure=3, do_import=do_import),
            self._insert_dataset(registry, "run4", "calexp", detector=4, visit=4, do_import=do_import),
            self._insert_dataset(registry, "run5", "calexp", detector=4, visit=4, do_import=do_import),
            # This dataset type is not configured, will not be in obscore.
            self._insert_dataset(registry, "run5", "no_obscore", detector=1, visit=1, do_import=do_import),
            self._insert_dataset(registry, "run6", "raw", detector=1, exposure=4, do_import=do_import),
        ]

    def test_config_errors(self):
        """Test for handling various configuration problems."""
        # This raises pydantic ValidationError, which wraps ValueError.
        exception_re = "'collections' must have one element"
        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry(None, "TAGGED")

        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry([], "TAGGED")

        with self.assertRaisesRegex(ValueError, exception_re):
            self.make_registry(["run1", "run2"], "TAGGED")

        # Invalid regex.
        with self.assertRaisesRegex(ValueError, "Failed to compile regex"):
            self.make_registry(["+run"], "RUN")

    def test_schema(self):
        """Check how obscore schema is constructed."""
        config = ObsCoreConfig(obs_collection="", dataset_types={}, facility_name="FACILITY")
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(list(table_spec.fields.names), [col.name for col in _STATIC_COLUMNS])

        # Extra columns from top-level config.
        config = ObsCoreConfig(
            obs_collection="",
            extra_columns={"c1": 1, "c2": "string", "c3": {"template": "{calib_level}", "type": "float"}},
            dataset_types={},
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(
            list(table_spec.fields.names),
            [col.name for col in _STATIC_COLUMNS] + ["c1", "c2", "c3"],
        )
        self.assertEqual(table_spec.fields["c1"].dtype, sqlalchemy.BigInteger)
        self.assertEqual(table_spec.fields["c2"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["c3"].dtype, sqlalchemy.Float)

        # Extra columns from per-dataset type configs.
        config = ObsCoreConfig(
            obs_collection="",
            extra_columns={"c1": 1},
            dataset_types={
                "raw": DatasetTypeConfig(
                    name="raw",
                    dataproduct_type="image",
                    calib_level=1,
                    extra_columns={"c2": "string"},
                ),
                "calexp": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=2,
                    extra_columns={"c3": 1e10},
                ),
            },
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(
            list(table_spec.fields.names),
            [col.name for col in _STATIC_COLUMNS] + ["c1", "c2", "c3"],
        )
        self.assertEqual(table_spec.fields["c1"].dtype, sqlalchemy.BigInteger)
        self.assertEqual(table_spec.fields["c2"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["c3"].dtype, sqlalchemy.Float)

        # Columns with the same names as in the static list in configs; types
        # are not overridden.
        config = ObsCoreConfig(
            version=0,
            obs_collection="",
            extra_columns={"t_xel": 1e10},
            dataset_types={
                "raw": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=1,
                    extra_columns={"target_name": 1},
                ),
                "calexp": DatasetTypeConfig(
                    dataproduct_type="image",
                    calib_level=2,
                    extra_columns={"em_xel": "string"},
                ),
            },
            facility_name="FACILITY",
        )
        schema = ObsCoreSchema(config, [])
        table_spec = schema.table_spec
        self.assertEqual(list(table_spec.fields.names), [col.name for col in _STATIC_COLUMNS])
        self.assertEqual(table_spec.fields["t_xel"].dtype, sqlalchemy.Integer)
        self.assertEqual(table_spec.fields["target_name"].dtype, sqlalchemy.String)
        self.assertEqual(table_spec.fields["em_xel"].dtype, sqlalchemy.Integer)

    def test_insert_existing_collection(self):
        """Test insert and import registry methods, with various restrictions
        on collection names.
        """
        # First item is collections, second item is expected record count.
        test_data = (
            (None, 6),
            (["run1", "run2"], 2),
            (["run[34]"], 2),
            (["[rR]un[^6]"], 5),
        )

        for collections, count in test_data:
            for do_import in (False, True):
                # subTest reports which combination failed instead of
                # stopping at the first mismatch without context.
                with self.subTest(collections=collections, do_import=do_import):
                    registry = self.make_registry(collections)
                    obscore = registry.obsCoreTableManager
                    assert obscore is not None
                    self._insert_datasets(registry, do_import)

                    with obscore.query() as result:
                        rows = list(result)
                    self.assertEqual(len(rows), count)

                    # Also check `query` method with COUNT(*).
                    with obscore.query([sqlalchemy.sql.func.count()]) as result:
                        scalar = result.scalar_one()
                    self.assertEqual(scalar, count)

    def test_drop_datasets(self):
        """Test for dropping datasets after obscore insert."""
        collections = None
        registry = self.make_registry(collections)
        obscore = registry.obsCoreTableManager
        assert obscore is not None
        refs = self._insert_datasets(registry)

        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 6)

        # Drop a single dataset.
        registry.removeDatasets(ref for ref in refs if ref.run == "run1")
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 5)

        # Drop a whole run collection.
        registry.removeCollection("run6")
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 4)

    def test_associate(self):
        """Test for associating datasets to TAGGED collection."""
        collections = ["tagged"]
        registry = self.make_registry(collections, "TAGGED")
        obscore = registry.obsCoreTableManager
        assert obscore is not None
        refs = self._insert_datasets(registry)

        # Nothing is in the tagged collection yet, so obscore is empty.
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 0)

        # Associating the run1 dataset with "tagged" adds it to obscore.
        registry.associate("tagged", (ref for ref in refs if ref.run == "run1"))
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 1)

        # Associating the run3 dataset adds a second record.
        registry.associate("tagged", (ref for ref in refs if ref.run == "run3"))
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 2)

        # Disassociating it removes that record again.
        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run3"))
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 1)

        # Disassociating a never-associated dataset must be a no-op, not an error.
        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run2"))
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 1)

        registry.disassociate("tagged", (ref for ref in refs if ref.run == "run1"))
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 0)

    @unittest.skip("Temporary, while deprecation warnings are present.")
    def test_region_type_warning(self) -> None:
        """Test that a non-polygon region generates exactly one region-type
        warning.
        """
        collections = None
        registry = self.make_registry(collections)

        with warnings.catch_warnings(record=True) as warning_records:
            # Without "always", a warning already emitted earlier in the
            # process is suppressed by the warnings registry and not recorded.
            warnings.simplefilter("always")
            self._insert_dataset(registry, "run2", "calexp", detector=2, visit=9)

        # Only look at region-type warnings so that unrelated warnings (such as
        # deprecations) do not affect the count.
        region_records = [
            record for record in warning_records if "Unexpected region type" in str(record.message)
        ]
        self.assertEqual(len(region_records), 1)
        for record in region_records:
            self.assertRegex(
                str(record.message),
                "Unexpected region type: .*lsst.sphgeom._sphgeom.Box.*",
            )

    def test_update_exposure_region(self) -> None:
        """Test for update_exposure_regions method."""
        registry = self.make_registry(["run1"])
        obscore = registry.obsCoreTableManager
        assert obscore is not None

        # Exposure 4 is not associated with any visit.
        for detector in (1, 2, 3, 4):
            self._insert_dataset(registry, "run1", "raw", detector=detector, exposure=4)

        # All spatial columns should be None.
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 4)
        for row in rows:
            self.assertIsNone(row.s_ra)
            self.assertIsNone(row.s_dec)
            self.assertIsNone(row.s_region)

        # Assign regions from visit 4 to detectors 1 and 2 only.
        count = obscore.update_exposure_regions(
            "DummyCam", [(4, 1, self.regions[(4, 1)]), (4, 2, self.regions[(4, 2)])]
        )
        self.assertEqual(count, 2)

        with obscore.query(["s_ra", "s_dec", "s_region", "lsst_detector"]) as result:
            rows = list(result)
        self.assertEqual(len(rows), 4)
        for row in rows:
            if row.lsst_detector in (1, 2):
                self.assertIsNotNone(row.s_ra)
                self.assertIsNotNone(row.s_dec)
                self.assertIsNotNone(row.s_region)
            else:
                self.assertIsNone(row.s_ra)
                self.assertIsNone(row.s_dec)
                self.assertIsNone(row.s_region)

499 self.assertIsNone(row.s_ra) 

500 self.assertIsNone(row.s_dec) 

501 self.assertIsNone(row.s_region) 

502 

503 

class SQLiteObsCoreTest(ObsCoreTests, unittest.TestCase):
    """Unit test for obscore with SQLite backend."""

    def setUp(self):
        self.root = makeTestTempDir(TESTDIR)

    def tearDown(self):
        removeTestTempDir(self.root)

    def make_registry_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> RegistryConfig:
        # docstring inherited from a base class
        # mkstemp returns an already-open OS-level file descriptor. Only the
        # unique file name is needed for the database URL, so close the
        # descriptor right away instead of leaking one per registry.
        fd, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
        os.close(fd)
        config = RegistryConfig()
        config["db"] = f"sqlite:///{filename}"
        config["managers", "obscore"] = {
            "cls": "lsst.daf.butler.registry.obscore.ObsCoreLiveTableManager",
            "config": self.make_obscore_config(collections, collection_type),
        }
        return config

525 

526 

class ClonedSqliteObscoreTest(SQLiteObsCoreTest, unittest.TestCase):
    """Unit test for an obscore manager obtained by copying a registry.

    Every test in the base class runs against ``registry.copy()`` instead of
    the freshly created registry.
    """

    def make_registry(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> SqlRegistry:
        """Create a new populated Registry and return a copy of it."""
        return super().make_registry(collections, collection_type).copy()

536 

537 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresObsCoreTest(ObsCoreTests, unittest.TestCase):
    """Unit test for obscore with PostgreSQL backend."""

    @classmethod
    def _handler(cls, postgresql):
        """Install database extensions into a newly initialized server.

        Passed as ``on_initialized`` to the PostgreSQL test server factory.
        """
        engine = sqlalchemy.engine.create_engine(postgresql.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close pooled connections so they do not outlive initialization.
            engine.dispose()

    @classmethod
    def setUpClass(cls):
        # Create the postgres test server factory.
        cls.postgresql = testing.postgresql.PostgresqlFactory(
            cache_initialized_db=True, on_initialized=cls._handler
        )
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        cls.postgresql.clear_cache()
        super().tearDownClass()

    def setUp(self):
        self.root = makeTestTempDir(TESTDIR)
        # Each test gets its own server started from the cached database.
        self.server = self.postgresql()
        self.count = 0

    def tearDown(self):
        removeTestTempDir(self.root)
        # Stop the per-test server started in setUp. Calling the factory here
        # would start yet another server and leak a process per test.
        self.server.stop()

    def make_registry_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> RegistryConfig:
        # docstring inherited from a base class
        self.count += 1
        config = RegistryConfig()
        config["db"] = self.server.url()
        # Use unique namespace for each instance, some tests may use sub-tests.
        config["namespace"] = f"namespace{self.count}"
        config["managers", "obscore"] = {
            "cls": "lsst.daf.butler.registry.obscore.ObsCoreLiveTableManager",
            "config": self.make_obscore_config(collections, collection_type),
        }
        return config

587 

588 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresPgSphereObsCoreTest(PostgresObsCoreTest):
    """Unit test for obscore with PostgreSQL backend and pgsphere plugin."""

    @classmethod
    def _handler(cls, postgresql):
        """Install base extensions plus ``pg_sphere``.

        Raises `unittest.SkipTest` when the ``pg_sphere`` extension cannot be
        created on the test server.
        """
        super()._handler(postgresql)
        engine = sqlalchemy.engine.create_engine(postgresql.url())
        try:
            with engine.begin() as connection:
                try:
                    connection.execute(sqlalchemy.text("CREATE EXTENSION pg_sphere"))
                except sqlalchemy.exc.DatabaseError as exc:
                    raise unittest.SkipTest(f"pg_sphere extension does not exist: {exc}") from None
        finally:
            # Release pooled connections on both the success and skip paths.
            engine.dispose()

    def make_obscore_config(
        self, collections: list[str] | None = None, collection_type: str | None = None
    ) -> Config:
        """Make configuration for obscore manager with the pgsphere plugin."""
        obscore_config = super().make_obscore_config(collections, collection_type)
        obscore_config["spatial_plugins"] = {
            "pgsphere": {
                "cls": "lsst.daf.butler.registry.obscore.pgsphere.PgSphereObsCorePlugin",
                "config": {
                    "region_column": "pgs_region",
                    "position_column": "pgs_center",
                },
            }
        }
        return obscore_config

    def test_spatial(self):
        """Test that pgsphere plugin fills spatial columns."""
        collections = None
        registry = self.make_registry(collections)
        obscore = registry.obsCoreTableManager
        assert obscore is not None
        self._insert_datasets(registry)

        # Select everything.
        with obscore.query() as result:
            rows = list(result)
        self.assertEqual(len(rows), 6)

        db = registry._db
        table = cast(ObsCoreLiveTableManager, obscore).table

        # It's not easy to generate spatial queries in sqlalchemy, use plain
        # text queries for testing.
        def count_rows(where: str) -> int:
            """Return the number of obscore rows matching a SQL condition."""
            query = f"SELECT * FROM {table.key} WHERE {where}"
            with db.query(sqlalchemy.text(query)) as results:
                return len(list(results))

        # Position close to visit=1 centers; there is a single dataset.
        self.assertEqual(count_rows("pgs_center <-> '(2d,0d)'::spoint < .1"), 1)

        # Position close to visit=4 centers; there are two datasets.
        self.assertEqual(count_rows("pgs_center <-> '(272d,0d)'::spoint < .1"), 2)

        # Point inside the visit=1, detector=1 region; a single dataset.
        self.assertEqual(count_rows("'(2d,-3d)'::spoint @ pgs_region"), 1)

        # Point inside the visit=4, detector=4 region; two datasets.
        self.assertEqual(count_rows("'(272d,3d)'::spoint @ pgs_region"), 2)

658 

659 

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()