# tests/test_dimensions.py

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionElement,
    DimensionGroup,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    YamlRepoImportBackend,
    ddl,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory
from lsst.daf.butler.timespan_database_representation import TimespanDatabaseRepresentation

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = registry.dimensions.conform(["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
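
# A quick sketch of what the loaded sequence provides (the actual values come
# from hsc-rc2-subset.yaml and are not reproduced here):
#
#     data_ids = loadDimensionData()
#     assert data_ids.hasFull() and data_ids.hasRecords()
#     # conform() pulls in the governor dimensions as well, so the required
#     # dimensions are instrument, visit, detector, skymap, tract, and patch.
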

class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            dimensions=self._dimensions,
        )
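
    # A minimal usage sketch (hypothetical data IDs; testPackers below
    # exercises this for real):
    #
    #     packer = ConcreteTestDimensionPacker(instrument_data_id, dimensions)
    #     packed, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
    #     assert packed == detector_data_id["detector"]
    #     assert packer.unpack(packed) == detector_data_id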


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGroupInvariants(self, group: DimensionGroup):
        elements = list(group.elements)
        for n, element_name in enumerate(elements):
            element = self.universe[element_name]
            # Ordered comparisons on groups behave like sets.
            self.assertLessEqual(element.minimal_group, group)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other_name in elements[:n]:
                other = self.universe[other_name]
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other_name in elements[n + 1 :]:
                other = self.universe[other_name]
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.minimal_group.required, element.required)
        self.assertEqual(self.universe.conform(group.required), group)
        self.assertCountEqual(
            group.required,
            [
                dimension_name
                for dimension_name in group.names
                if not any(
                    dimension_name in self.universe[other_name].minimal_group.implied
                    for other_name in group.elements
                )
            ],
        )
        self.assertCountEqual(group.implied, group.names - group.required)
        self.assertCountEqual(group.names, itertools.chain(group.required, group.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and any element that is implied by another in the group
        # should follow at least one of the elements that imply it.
        seen: set[str] = set()
        for element_name in group.lookup_order:
            element = self.universe[element_name]
            with self.subTest(required=group.required, implied=group.implied, element=element):
                seen.add(element_name)
                self.assertLessEqual(element.minimal_group.required, seen)
                if element_name in group.implied:
                    self.assertTrue(any(element_name in self.universe[s].implied for s in seen))
        self.assertCountEqual(seen, group.elements)
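
        # As a concrete example of the required/implied partition checked
        # above (default universe): conforming just ["visit"] yields
        # required={"instrument", "visit"} and
        # implied={"physical_filter", "band"}, and names is always the
        # disjoint union of the two (see testInstrumentDimensions below).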

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))
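
        # In short: the same namespace with differing versions is treated as
        # compatible (with an INFO log message), while a differing namespace
        # is not.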

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(1, 25)}
            | {f"healpix{level}" for level in range(1, 18)},
        )

    def testGraphs(self):
        self.checkGroupInvariants(self.universe.empty.as_group())
        for element in self.universe.getStaticElements():
            self.checkGroupInvariants(element.minimal_group)

    def testInstrumentDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.governors, {"instrument"})
        for element in group.elements:
            self.assertEqual(self.universe[element].has_own_table, element != "band", element)
            self.assertEqual(
                self.universe[element].implied_union_target,
                "physical_filter" if element == "band" else None,
                element,
            )
            self.assertEqual(
                self.universe[element].defines_relationships,
                element
                in ("visit", "exposure", "physical_filter", "visit_definition", "visit_system_membership"),
                element,
            )

    def testCalibrationDimensions(self):
        group = self.universe.conform(["physical_filter", "detector"])
        self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(group.implied, ("band",))
        self.assertCountEqual(group.elements, group.names)
        self.assertCountEqual(group.governors, {"instrument"})

    def testObservationDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.spatial.names, ("observation_regions",))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertEqual(group.spatial.names, {"observation_regions"})
        self.assertEqual(group.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        group = self.universe.conform(["patch"])
        self.assertEqual(group.names, {"skymap", "tract", "patch"})
        self.assertEqual(group.required, {"skymap", "tract", "patch"})
        self.assertEqual(group.implied, set())
        self.assertEqual(group.elements, group.names)
        self.assertEqual(group.governors, {"skymap"})
        self.assertEqual(group.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"])
        self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.has_own_table:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.dimensions)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            group1 = element1.minimal_group
            group2 = pickle.loads(pickle.dumps(group1))
            self.assertIs(group1, group2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this attribute
    if and only if ``minimal.dimensions.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator[DataCoordinate]:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is the default), iterate over only the
            ``n``th data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
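
    # For example (a sketch; the attributes are populated by
    # DataCoordinateTestCase.splitByStateFlags below):
    #
    #     split = SplitByStateFlags(minimal=m, complete=c, expanded=e)
    #     list(split.chain())   # every data ID in m, then c, then e
    #     list(split.chain(0))  # only the first data ID of each attribute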


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            dimensions=dataIds.dimensions,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup:
        """Generate a random `DimensionGroup` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGroup`.
        group : `DimensionGroup`, optional
            Dimensions to select from. Defaults to
            ``self.allDataIds.dimensions``.

        Returns
        -------
        selected : `DimensionGroup`
            ``n`` or more dimensions randomly selected from ``group`` without
            replacement.
        """
        if group is None:
            group = self.allDataIds.dimensions
        return group.universe.conform(self.rng.sample(list(group.names), min(n, len(group.names))))
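
    # Note that conform() can expand the sampled names: sampling just "visit",
    # for example, yields a group that also contains "instrument",
    # "physical_filter", and "band", which is why the docstring says
    # "or more".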

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
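
    # Sketch of the three states this produces for a fully-expanded input:
    #
    #     split.expanded  # the original data IDs, with dimension records
    #     split.complete  # the same key-value pairs, records stripped
    #     split.minimal   # only the required-dimension values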

    def testMappingViews(self):
        """Test that the ``mapping`` and ``required`` attributes in
        `DataCoordinate` are self-consistent and consistent with the
        ``dimensions`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
                    self.assertEqual(
                        list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(
                        list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required]
                    )
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.dimensions.names, dataId.mapping.keys())
                    self.assertEqual(
                        list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()]
                    )

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id: DataCoordinate = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.mapping, read_data_id.mapping)
                if data_id.hasRecords():
                    for element_name in data_id.dimensions.elements:
                        self.assertEqual(
                            data_id.records[element_name], read_data_id.records[element_name]
                        )

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element_name in data_id.dimensions.elements:
                    self.assertIs(getattr(data_id, element_name), data_id.records[element_name])
                    self.assertIn(element_name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element_name in data_id.dimensions.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element_name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values; inequality should hold in
        # both directions.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(
                    dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12)
                )
                # Test constructing a new data ID from this one with a subset
                # of the dimensions.  This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions)
                if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required:
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.dimensions.required & newDataId.dimensions.required
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or from Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2))
                keys2 = dataId.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe),
                        dimensions=dataId.dimensions,
                        **dataId.mapping,
                    ),
                    DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        dimensions=dataId.dimensions,
                        **{k: dataId[k] for k in keys2},
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test groups to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        groups = [self.randomDimensionSubset(n=2) for _ in range(2)]
        groups.append(self.allDataIds.universe["visit"].minimal_group)
        groups.append(self.allDataIds.universe["detector"].minimal_group)
        groups.append(self.allDataIds.universe["physical_filter"].minimal_group)
        groups.append(self.allDataIds.universe["band"].minimal_group)
        # Iterate over all combinations, including the same group with itself.
        for group1, group2 in itertools.product(groups, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(group1))
            split2 = self.splitByStateFlags(parentDataIds.subset(group2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.dimensions, group1.union(group2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.dimensions))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.dimensions), lhs)
                        self.assertEqual(unioned.subset(rhs.dimensions), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.dimensions >= unioned.dimensions and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.dimensions >= unioned.dimensions and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements
                    ):
                        self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            self.allDataIds.universe.conform(["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            if dataIds.dimensions.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.conform(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)
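
        # The two bit-width formulas checked above agree for any integer
        # n >= 1: (n - 1).bit_length() == math.ceil(math.log2(n)); e.g.
        # n = 4 gives 2 and n = 5 gives 3 (floating-point rounding aside
        # for very large n).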


if __name__ == "__main__":
    unittest.main()