Coverage for tests/test_dimensions.py: 11%

493 statements  

coverage.py v7.13.5, created at 2026-05-06 08:30 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import copy 

29import itertools 

30import math 

31import os 

32import pickle 

33import unittest 

34from collections.abc import Iterator 

35from dataclasses import dataclass 

36from random import Random 

37 

38import pydantic 

39 

40import lsst.sphgeom 

41from lsst.daf.butler import ( 

42 Config, 

43 DataCoordinate, 

44 DataCoordinateSequence, 

45 DataCoordinateSet, 

46 DatasetType, 

47 Dimension, 

48 DimensionConfig, 

49 DimensionElement, 

50 DimensionGroup, 

51 DimensionNameError, 

52 DimensionPacker, 

53 DimensionUniverse, 

54 NamedKeyDict, 

55 NamedValueSet, 

56 ddl, 

57) 

58from lsst.daf.butler.tests.utils import create_populated_sqlite_registry 

59from lsst.daf.butler.timespan_database_representation import TimespanDatabaseRepresentation 

60 

61DIMENSION_DATA_FILE = os.path.normpath( 

62 os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml") 

63) 

64 

65 

66def loadDimensionData() -> DataCoordinateSequence: 

67 """Load dimension data from an export file included in the code repository. 

68 

69 Returns 

70 ------- 

71 dataIds : `DataCoordinateSequence` 

72 A sequence containing all data IDs in the export file. 

73 """ 

74 # Create an in-memory SQLite database and Registry just to import the YAML 

75 # data and retrieve it as a set of DataCoordinate objects. 

76 with create_populated_sqlite_registry(DIMENSION_DATA_FILE) as butler: 

77 dimensions = butler.registry.dimensions.conform(["visit", "detector", "tract", "patch"]) 

78 return butler.registry.queryDataIds(dimensions).expanded().toSequence() 

79 

80 
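# --- Editor's illustrative sketch (not part of the original test module). ---
# The sequence returned by loadDimensionData() is fully expanded: it carries
# values for implied dimensions as well as the dimension records themselves,
# which is what splitByStateFlags() in DataCoordinateTestCase relies on.
# Assumes the hsc-rc2-subset.yaml export file referenced above is present.
def _demo_loaded_dimension_data() -> None:
    data_ids = loadDimensionData()
    # Both flags are True for an expanded sequence.
    assert data_ids.hasFull() and data_ids.hasRecords()
    for data_id in data_ids[:1]:
        # Records are attached, so e.g. the visit record's timespan can be
        # read directly from the data ID.
        print(data_id.mapping)
        print(data_id.records["visit"].timespan)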

81class ConcreteTestDimensionPacker(DimensionPacker): 

82 """A concrete `DimensionPacker` for testing its base class implementations. 

83 

84 This class just returns the detector ID as-is. 

85 """ 

86 

87 def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup): 

88 super().__init__(fixed, dimensions) 

89 self._n_detectors = fixed.records["instrument"].detector_max 

90 self._max_bits = (self._n_detectors - 1).bit_length() 

91 

92 @property 

93 def maxBits(self) -> int: 

94 # Docstring inherited from DimensionPacker.maxBits 

95 return self._max_bits 

96 

97 def _pack(self, dataId: DataCoordinate) -> int: 

98 # Docstring inherited from DimensionPacker._pack 

99 return dataId["detector"] 

100 

101 def unpack(self, packedId: int) -> DataCoordinate: 

102 # Docstring inherited from DimensionPacker.unpack 

103 return DataCoordinate.standardize( 

104 { 

105 "instrument": self.fixed["instrument"], 

106 "detector": packedId, 

107 }, 

108 dimensions=self._dimensions, 

109 ) 

110 

111 
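# --- Editor's illustrative sketch (not part of the original test module). ---
# How the packer above is meant to be exercised, mirroring
# DataCoordinateTestCase.testPackers further down.  ``fixed`` is assumed to be
# an expanded data ID for the instrument (its record supplies detector_max);
# ``data_id`` identifies one detector of that instrument.
def _demo_packer_round_trip(fixed: DataCoordinate, data_id: DataCoordinate) -> None:
    packer = ConcreteTestDimensionPacker(fixed, data_id.dimensions)
    packed_id, max_bits = packer.pack(data_id, returnMaxBits=True)
    # This trivial packer forwards the detector ID unchanged, so packing is
    # the identity and unpacking reconstructs the same data ID.
    assert packed_id == data_id["detector"]
    assert max_bits == packer.maxBits
    assert packer.unpack(packed_id) == data_id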

112class DimensionTestCase(unittest.TestCase): 

113 """Tests for dimensions. 

114 

115 All tests here rely on the content of ``config/dimensions.yaml``, either 

116 to test that the definitions there are read in properly or just as generic 

117 data for testing various operations. 

118 """ 

119 

120 def setUp(self): 

121 self.universe = DimensionUniverse() 

122 

123 def checkGroupInvariants(self, group: DimensionGroup): 

124 elements = list(group.elements) 

125 for n, element_name in enumerate(elements): 

126 element = self.universe[element_name] 

127 # Ordered comparisons on graphs behave like sets. 

128 self.assertLessEqual(element.minimal_group, group) 

129 # Ordered comparisons on elements correspond to the ordering within 

130 # a DimensionUniverse (topological, with deterministic 

131 # tiebreakers). 

132 for other_name in elements[:n]: 

133 other = self.universe[other_name] 

134 self.assertLess(other, element) 

135 self.assertLessEqual(other, element) 

136 for other_name in elements[n + 1 :]: 

137 other = self.universe[other_name] 

138 self.assertGreater(other, element) 

139 self.assertGreaterEqual(other, element) 

140 if isinstance(element, Dimension): 

141 self.assertEqual(element.minimal_group.required, element.required) 

142 self.assertEqual(self.universe.conform(group.required), group) 

143 self.assertCountEqual( 

144 group.required, 

145 [ 

146 dimension_name 

147 for dimension_name in group.names 

148 if not any( 

149 dimension_name in self.universe[other_name].minimal_group.implied 

150 for other_name in group.elements 

151 ) 

152 ], 

153 ) 

154 self.assertCountEqual(group.implied, group.names - group.required) 

155 self.assertCountEqual(group.names, itertools.chain(group.required, group.implied)) 

156 # Check primary key traversal order: each element should follow any it 

157 # requires, and any element that is implied by another in the group should 

158 # follow at least one of those. 

159 seen: set[str] = set() 

160 for element_name in group.lookup_order: 

161 element = self.universe[element_name] 

162 with self.subTest( 

163 required=repr(group.required), implied=repr(group.implied), element=repr(element) 

164 ): 

165 seen.add(element_name) 

166 self.assertLessEqual(element.minimal_group.required, seen) 

167 if element_name in group.implied: 

168 self.assertTrue(any(element_name in self.universe[s].implied for s in seen)) 

169 self.assertCountEqual(seen, group.elements) 

170 

171 def testConfigPresent(self): 

172 config = self.universe.dimensionConfig 

173 self.assertIsInstance(config, DimensionConfig) 

174 

175 def test_group_union(self) -> None: 

176 """Test unions of DimensionGroups.""" 

177 a = self.universe.conform(["visit"]) 

178 b = self.universe.conform(["detector"]) 

179 self.assertIs(a.union(b), self.universe.conform(["visit", "detector"])) 

180 self.assertIs(DimensionGroup.union(a, b), self.universe.conform(["visit", "detector"])) 

181 self.assertIs(DimensionGroup.union(universe=self.universe), self.universe.empty) 

182 with self.assertRaises(TypeError): 

183 DimensionGroup.union() 

184 

185 def test_group_set_ops(self) -> None: 

186 """Test set operations on DimensionGroups.""" 

187 a = self.universe.conform(["visit"]) 

188 b = self.universe.conform(["detector"]) 

189 c = self.universe.conform(["visit", "detector"]) 

190 d = self.universe.conform(["physical_filter"]) 

191 e = self.universe.conform(["detector", "physical_filter"]) 

192 self.assertEqual(a.union(b), c) 

193 self.assertEqual(a.union(d), a) 

194 self.assertEqual(b.union(d), e) 

195 self.assertEqual(a.difference(b), a) 

196 self.assertEqual(a.difference(c), self.universe.empty) 

197 self.assertEqual(a.difference(d), a) 

198 self.assertEqual(c.difference(a), b) 

199 self.assertEqual(a.intersection(c), a) 

200 self.assertEqual(a.intersection(d), d) 

201 self.assertEqual(a.intersection(e), d) 

202 

203 def testCompatibility(self): 

204 # Simple check that should always be true. 

205 self.assertTrue(self.universe.isCompatibleWith(self.universe)) 

206 

207 # Create a universe like the default universe but with a different 

208 # version number. 

209 clone = self.universe.dimensionConfig.copy() 

210 clone["version"] = clone["version"] + 1_000_000 # High version number 

211 universe_clone = DimensionUniverse(config=clone) 

212 with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm: 

213 self.assertTrue(self.universe.isCompatibleWith(universe_clone)) 

214 self.assertIn("differing versions", "\n".join(cm.output)) 

215 

216 # Create completely incompatible universe. 

217 config = Config( 

218 { 

219 "version": 1, 

220 "namespace": "compat_test", 

221 "skypix": { 

222 "common": "htm7", 

223 "htm": { 

224 "class": "lsst.sphgeom.HtmPixelization", 

225 "max_level": 24, 

226 }, 

227 }, 

228 "elements": { 

229 "A": { 

230 "keys": [ 

231 { 

232 "name": "id", 

233 "type": "int", 

234 } 

235 ], 

236 "storage": { 

237 "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage", 

238 }, 

239 }, 

240 "B": { 

241 "keys": [ 

242 { 

243 "name": "id", 

244 "type": "int", 

245 } 

246 ], 

247 "storage": { 

248 "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage", 

249 }, 

250 }, 

251 }, 

252 "packers": {}, 

253 } 

254 ) 

255 universe2 = DimensionUniverse(config=config) 

256 self.assertFalse(universe2.isCompatibleWith(self.universe)) 

257 

258 def testVersion(self): 

259 self.assertEqual(self.universe.namespace, "daf_butler") 

260 # Test was added starting at version 2. 

261 self.assertGreaterEqual(self.universe.version, 2) 

262 

263 def testConfigRead(self): 

264 self.assertEqual( 

265 set(self.universe.getStaticDimensions().names), 

266 { 

267 "instrument", 

268 "visit", 

269 "visit_system", 

270 "exposure", 

271 "detector", 

272 "group", 

273 "day_obs", 

274 "physical_filter", 

275 "band", 

276 "subfilter", 

277 "skymap", 

278 "tract", 

279 "patch", 

280 "ssp_hypothesis_table", 

281 "ssp_hypothesis_bundle", 

282 "ssp_balanced_index", 

283 } 

284 | {f"htm{level}" for level in range(1, 25)} 

285 | {f"healpix{level}" for level in range(1, 18)}, 

286 ) 

287 

288 def testGraphs(self): 

289 self.checkGroupInvariants(self.universe.empty) 

290 for element in self.universe.getStaticElements(): 

291 self.checkGroupInvariants(element.minimal_group) 

292 

293 def testInstrumentDimensions(self): 

294 group = self.universe.conform(["exposure", "detector", "visit"]) 

295 self.assertCountEqual( 

296 group.names, 

297 ("instrument", "exposure", "detector", "visit", "physical_filter", "band", "group", "day_obs"), 

298 ) 

299 self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit")) 

300 self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs")) 

301 self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition")) 

302 self.assertCountEqual(group.governors, {"instrument"}) 

303 for element in group.elements: 

304 self.assertEqual(self.universe[element].has_own_table, element != "band", element) 

305 self.assertEqual( 

306 self.universe[element].implied_union_target, 

307 "physical_filter" if element == "band" else None, 

308 element, 

309 ) 

310 self.assertEqual( 

311 self.universe[element].defines_relationships, 

312 element 

313 in ("visit", "exposure", "physical_filter", "visit_definition", "visit_system_membership"), 

314 element, 

315 ) 

316 

317 def testCalibrationDimensions(self): 

318 group = self.universe.conform(["physical_filter", "detector"]) 

319 self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band")) 

320 self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter")) 

321 self.assertCountEqual(group.implied, ("band",)) 

322 self.assertCountEqual(group.elements, group.names) 

323 self.assertCountEqual(group.governors, {"instrument"}) 

324 self.assertIsNone(group.region_dimension) 

325 self.assertIsNone(group.timespan_dimension) 

326 

327 def testObservationDimensions(self): 

328 group = self.universe.conform(["exposure", "detector", "visit"]) 

329 self.assertCountEqual( 

330 group.names, 

331 ("instrument", "detector", "visit", "exposure", "physical_filter", "band", "group", "day_obs"), 

332 ) 

333 self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit")) 

334 self.assertCountEqual(group.implied, ("physical_filter", "band", "group", "day_obs")) 

335 self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition")) 

336 self.assertCountEqual(group.spatial.names, ("observation_regions",)) 

337 self.assertCountEqual(group.temporal.names, ("observation_timespans",)) 

338 self.assertCountEqual(group.governors, {"instrument"}) 

339 self.assertEqual(group.region_dimension, "visit_detector_region") 

340 self.assertEqual(group.timespan_dimension, "exposure") 

341 self.assertEqual(group.spatial.names, {"observation_regions"}) 

342 self.assertEqual(group.temporal.names, {"observation_timespans"}) 

343 self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"]) 

344 self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"]) 

345 self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"]) 

346 self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"]) 

347 self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"]) 

348 self.assertEqual( 

349 self.universe.get_elements_populated_by(self.universe["visit"]), 

350 NamedValueSet( 

351 { 

352 self.universe["visit"], 

353 self.universe["visit_definition"], 

354 self.universe["visit_system_membership"], 

355 self.universe["visit_detector_region"], 

356 } 

357 ), 

358 ) 

359 

360 def testSkyMapDimensions(self): 

361 group = self.universe.conform(["patch"]) 

362 self.assertEqual(group.names, {"skymap", "tract", "patch"}) 

363 self.assertEqual(group.required, {"skymap", "tract", "patch"}) 

364 self.assertEqual(group.implied, set()) 

365 self.assertEqual(group.elements, group.names) 

366 self.assertEqual(group.governors, {"skymap"}) 

367 self.assertEqual(group.spatial.names, {"skymap_regions"}) 

368 self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"]) 

369 

370 def testSubsetCalculation(self): 

371 """Test that independent spatial and temporal options are computed 

372 correctly. 

373 """ 

374 group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"]) 

375 self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm")) 

376 self.assertCountEqual(group.temporal.names, ("observation_timespans",)) 

377 self.assertEqual(group.timespan_dimension, "exposure") 

378 # Cannot choose among visit_detector_region, htm7, and tract/patch. 

379 self.assertIsNone(group.region_dimension) 

380 

381 def testSchemaGeneration(self): 

382 tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({}) 

383 for element in self.universe.getStaticElements(): 

384 if element.has_own_table: 

385 tableSpecs[element] = element.RecordClass.fields.makeTableSpec( 

386 TimespanReprClass=TimespanDatabaseRepresentation.Compound, 

387 ) 

388 for element, tableSpec in tableSpecs.items(): 

389 for dep in element.required: 

390 with self.subTest(element=element.name, dep=dep.name): 

391 if dep != element: 

392 self.assertIn(dep.name, tableSpec.fields) 

393 self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype) 

394 self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length) 

395 self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes) 

396 self.assertFalse(tableSpec.fields[dep.name].nullable) 

397 self.assertTrue(tableSpec.fields[dep.name].primaryKey) 

398 else: 

399 self.assertIn(element.primaryKey.name, tableSpec.fields) 

400 self.assertEqual( 

401 tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype 

402 ) 

403 self.assertEqual( 

404 tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length 

405 ) 

406 self.assertEqual( 

407 tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes 

408 ) 

409 self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable) 

410 self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey) 

411 for dep in element.implied: 

412 with self.subTest(element=element.name, dep=dep.name): 

413 self.assertIn(dep.name, tableSpec.fields) 

414 self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype) 

415 self.assertFalse(tableSpec.fields[dep.name].primaryKey) 

416 for foreignKey in tableSpec.foreignKeys: 

417 self.assertIn(foreignKey.table, tableSpecs) 

418 self.assertIn(foreignKey.table, element.dimensions) 

419 self.assertEqual(len(foreignKey.source), len(foreignKey.target)) 

420 for source, target in zip(foreignKey.source, foreignKey.target, strict=True): 

421 self.assertIn(source, tableSpec.fields.names) 

422 self.assertIn(target, tableSpecs[foreignKey.table].fields.names) 

423 self.assertEqual( 

424 tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype 

425 ) 

426 self.assertEqual( 

427 tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length 

428 ) 

429 self.assertEqual( 

430 tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes 

431 ) 

432 

433 def testPickling(self): 

434 # Pickling and copying should always yield the exact same object within 

435 # a single process (cross-process is impossible to test here). 

436 universe1 = DimensionUniverse() 

437 universe2 = pickle.loads(pickle.dumps(universe1)) 

438 universe3 = copy.copy(universe1) 

439 universe4 = copy.deepcopy(universe1) 

440 self.assertIs(universe1, universe2) 

441 self.assertIs(universe1, universe3) 

442 self.assertIs(universe1, universe4) 

443 for element1 in universe1.getStaticElements(): 

444 element2 = pickle.loads(pickle.dumps(element1)) 

445 self.assertIs(element1, element2) 

446 group1 = element1.minimal_group 

447 group2 = pickle.loads(pickle.dumps(group1)) 

448 self.assertIs(group1, group2) 

449 

450 def testSerialization(self): 

451 # Check that dataset types round-trip correctly through serialization. 

452 dataset_type = DatasetType( 

453 "flat", 

454 dimensions=["instrument", "detector", "physical_filter", "band"], 

455 isCalibration=True, 

456 universe=self.universe, 

457 storageClass="int", 

458 ) 

459 roundtripped = DatasetType.from_simple(dataset_type.to_simple(), universe=self.universe) 

460 

461 self.assertEqual(dataset_type, roundtripped) 

462 

463 
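# --- Editor's illustrative sketch (not part of the original test module). ---
# A condensed view of what testInstrumentDimensions above verifies:
# DimensionUniverse.conform expands a few requested dimension names into a
# full DimensionGroup with required, implied, and governor dimensions.
def _demo_conform(universe: DimensionUniverse) -> None:
    group = universe.conform(["exposure", "detector", "visit"])
    # The governor dimension "instrument" is pulled in as a required
    # dimension; physical_filter, band, group, and day_obs come along as
    # implied dimensions.
    assert "instrument" in group.required
    assert "band" in group.implied
    assert group.governors == {"instrument"}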

464@dataclass 

465class SplitByStateFlags: 

466 """A struct that separates data IDs with different states but the same 

467 values. 

468 """ 

469 

470 minimal: DataCoordinateSequence | None = None 

471 """Data IDs that only contain values for required dimensions. 

472 

473 `DataCoordinateSequence.hasFull()` will return `True` for this if and only 

474 if ``minimal.dimensions.implied`` has no elements. 

475 `DataCoordinate.hasRecords()` will always return `False`. 

476 """ 

477 

478 complete: DataCoordinateSequence | None = None 

479 """Data IDs that contain values for all dimensions. 

480 

481 `DataCoordinateSequence.hasFull()` will always return `True` and 

482 `DataCoordinate.hasRecords()` will always return `False` for this attribute. 

483 """ 

484 

485 expanded: DataCoordinateSequence | None = None 

486 """Data IDs that contain values for all dimensions as well as records. 

487 

488 `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will 

489 always return `True` for this attribute. 

490 """ 

491 

492 def chain(self, n: int | None = None) -> Iterator[DataCoordinate]: 

493 """Iterate over the data IDs of different types. 

494 

495 Parameters 

496 ---------- 

497 n : `int`, optional 

498 If provided (`None` is default), iterate over only the ``nth`` 

499 data ID in each attribute. 

500 

501 Yields 

502 ------ 

503 dataId : `DataCoordinate` 

504 A data ID from one of the attributes in this struct. 

505 """ 

506 if n is None: 

507 s = slice(None, None) 

508 else: 

509 s = slice(n, n + 1) 

510 if self.minimal is not None: 

511 yield from self.minimal[s] 

512 if self.complete is not None: 

513 yield from self.complete[s] 

514 if self.expanded is not None: 

515 yield from self.expanded[s] 

516 

517 
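# --- Editor's illustrative sketch (not part of the original test module). ---
# How the three states described in the SplitByStateFlags docstrings are
# produced from a single expanded data ID, mirroring
# DataCoordinateTestCase.splitByStateFlags below.  ``expanded`` is assumed to
# have both hasFull() and hasRecords() True.
def _demo_state_flags(expanded: DataCoordinate) -> None:
    # Keeping all dimension values but dropping the records gives a
    # "complete" data ID.
    complete = DataCoordinate.standardize(expanded.mapping, dimensions=expanded.dimensions)
    assert complete.hasFull() and not complete.hasRecords()
    # Keeping only the required-dimension values gives a "minimal" data ID;
    # hasFull() then holds only if there are no implied dimensions.
    minimal = DataCoordinate.standardize(expanded.required, dimensions=expanded.dimensions)
    assert minimal.hasFull() == (not expanded.dimensions.implied)
    assert not minimal.hasRecords()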

518class DataCoordinateTestCase(unittest.TestCase): 

519 """Test `DataCoordinate`.""" 

520 

521 RANDOM_SEED = 10 

522 

523 @classmethod 

524 def setUpClass(cls): 

525 cls.allDataIds = loadDimensionData() 

526 

527 def setUp(self): 

528 self.rng = Random(self.RANDOM_SEED) 

529 

530 def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None): 

531 """Select random data IDs from those loaded from test data. 

532 

533 Parameters 

534 ---------- 

535 n : `int` 

536 Number of data IDs to select. 

537 dataIds : `DataCoordinateSequence`, optional 

538 Data IDs to select from. Defaults to ``self.allDataIds``. 

539 

540 Returns 

541 ------- 

542 selected : `DataCoordinateSequence` 

543 ``n`` Data IDs randomly selected from ``dataIds`` with replacement. 

544 """ 

545 if dataIds is None: 

546 dataIds = self.allDataIds 

547 return DataCoordinateSequence( 

548 self.rng.sample(dataIds, n), 

549 dimensions=dataIds.dimensions, 

550 hasFull=dataIds.hasFull(), 

551 hasRecords=dataIds.hasRecords(), 

552 check=False, 

553 ) 

554 

555 def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup: 

556 """Generate a random `DimensionGroup` that has a subset of the 

557 dimensions in a given one. 

558 

559 Parameters 

560 ---------- 

561 n : `int` 

562 Number of dimensions to select, before automatic expansion by 

563 `DimensionGroup`. 

564 group : `DimensionGroup`, optional 

565 Dimensions to select from. Defaults to 

566 ``self.allDataIds.dimensions``. 

567 

568 Returns 

569 ------- 

570 selected : `DimensionGroup` 

571 ``n`` or more dimensions randomly selected from ``group`` without 

572 replacement. 

573 """ 

574 if group is None: 

575 group = self.allDataIds.dimensions 

576 return group.universe.conform(self.rng.sample(list(group.names), max(n, len(group)))) 

577 

578 def splitByStateFlags( 

579 self, 

580 dataIds: DataCoordinateSequence | None = None, 

581 *, 

582 expanded: bool = True, 

583 complete: bool = True, 

584 minimal: bool = True, 

585 ) -> SplitByStateFlags: 

586 """Given a sequence of data IDs, generate new equivalent sequences 

587 containing less information. 

588 

589 Parameters 

590 ---------- 

591 dataIds : `DataCoordinateSequence`, optional 

592 Data IDs to start from. Defaults to ``self.allDataIds``. 

593 ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return 

594 `True`. 

595 expanded : `bool`, optional 

596 If `True` (default) include the original data IDs that contain all 

597 information in the result. 

598 complete : `bool`, optional 

599 If `True` (default) include data IDs for which ``hasFull()`` 

600 returns `True` but ``hasRecords()`` does not. 

601 minimal : `bool`, optional 

602 If `True` (default) include data IDs that only contain values for 

603 required dimensions, for which ``hasFull()`` may not return `True`. 

604 

605 Returns 

606 ------- 

607 split : `SplitByStateFlags` 

608 A dataclass holding the indicated data IDs in attributes that 

609 correspond to the boolean keyword arguments. 

610 """ 

611 if dataIds is None: 

612 dataIds = self.allDataIds 

613 assert dataIds.hasFull() and dataIds.hasRecords() 

614 result = SplitByStateFlags(expanded=dataIds) 

615 if complete: 

616 result.complete = DataCoordinateSequence( 

617 [ 

618 DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions) 

619 for e in result.expanded 

620 ], 

621 dimensions=dataIds.dimensions, 

622 ) 

623 self.assertTrue(result.complete.hasFull()) 

624 self.assertFalse(result.complete.hasRecords()) 

625 if minimal: 

626 result.minimal = DataCoordinateSequence( 

627 [ 

628 DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions) 

629 for e in result.expanded 

630 ], 

631 dimensions=dataIds.dimensions, 

632 ) 

633 self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied) 

634 self.assertFalse(result.minimal.hasRecords()) 

635 if not expanded: 

636 result.expanded = None 

637 return result 

638 

639 def testMappingViews(self): 

640 """Test that the ``mapping`` and ``required`` attributes in 

641 `DataCoordinate` are self-consistent and consistent with the 

642 ``dimensions`` property. 

643 """ 

644 for _ in range(5): 

645 dimensions = self.randomDimensionSubset() 

646 dataIds = self.randomDataIds(n=1).subset(dimensions) 

647 split = self.splitByStateFlags(dataIds) 

648 for dataId in split.chain(): 

649 with self.subTest(dataId=repr(dataId)): 

650 self.assertEqual(dataId.required.keys(), dataId.dimensions.required) 

651 self.assertEqual( 

652 list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required] 

653 ) 

654 self.assertEqual( 

655 list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required] 

656 ) 

657 self.assertEqual(dataId.required.keys(), dataId.dimensions.required) 

658 for dataId in itertools.chain(split.complete, split.expanded): 

659 with self.subTest(dataId=repr(dataId)): 

660 self.assertTrue(dataId.hasFull()) 

661 self.assertEqual(dataId.dimensions.names, dataId.mapping.keys()) 

662 self.assertEqual( 

663 list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()] 

664 ) 

665 

666 def test_pickle(self): 

667 for _ in range(5): 

668 dimensions = self.randomDimensionSubset() 

669 dataIds = self.randomDataIds(n=1).subset(dimensions) 

670 split = self.splitByStateFlags(dataIds) 

671 for data_id in split.chain(): 

672 s = pickle.dumps(data_id) 

673 read_data_id: DataCoordinate = pickle.loads(s) 

674 self.assertEqual(data_id, read_data_id) 

675 self.assertEqual(data_id.hasFull(), read_data_id.hasFull()) 

676 self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords()) 

677 if data_id.hasFull(): 

678 self.assertEqual(data_id.mapping, read_data_id.mapping) 

679 if data_id.hasRecords(): 

680 for element_name in data_id.dimensions.elements: 

681 self.assertEqual( 

682 data_id.records[element_name], read_data_id.records[element_name] 

683 ) 

684 

685 def test_record_attributes(self): 

686 """Test that dimension records are available as attributes on expanded 

687 data coordinates. 

688 """ 

689 for _ in range(5): 

690 dimensions = self.randomDimensionSubset() 

691 dataIds = self.randomDataIds(n=1).subset(dimensions) 

692 split = self.splitByStateFlags(dataIds) 

693 for data_id in split.expanded: 

694 for element_name in data_id.dimensions.elements: 

695 self.assertIs(getattr(data_id, element_name), data_id.records[element_name]) 

696 self.assertIn(element_name, dir(data_id)) 

697 with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"): 

698 data_id.not_a_dimension_name 

699 for data_id in itertools.chain(split.minimal, split.complete): 

700 for element_name in data_id.dimensions.elements: 

701 with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"): 

702 getattr(data_id, element_name) 

703 with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"): 

704 data_id.not_a_dimension_name 

705 

706 def testEquality(self): 

707 """Test that different `DataCoordinate` instances with different state 

708 flags can be compared with each other and other mappings. 

709 """ 

710 dataIds = self.randomDataIds(n=2) 

711 split = self.splitByStateFlags(dataIds) 

712 # Iterate over all combinations of different states of DataCoordinate, 

713 # with the same underlying data ID values. 

714 for a0, b0 in itertools.combinations(split.chain(0), 2): 

715 self.assertEqual(a0, b0) 

716 # Same thing, for a different data ID value. 

717 for a1, b1 in itertools.combinations(split.chain(1), 2): 

718 self.assertEqual(a1, b1) 

719 # Iterate over all combinations of different states of DataCoordinate, 

720 # with different underlying data ID values. 

721 for a0, b1 in itertools.product(split.chain(0), split.chain(1)): 

722 self.assertNotEqual(a0, b1) 

723 self.assertNotEqual(a1, b0) 

724 

725 def testStandardize(self): 

726 """Test constructing a DataCoordinate from many different kinds of 

727 input via `DataCoordinate.standardize` and `DataCoordinate.subset`. 

728 """ 

729 for _ in range(5): 

730 dimensions = self.randomDimensionSubset() 

731 dataIds = self.randomDataIds(n=1).subset(dimensions) 

732 split = self.splitByStateFlags(dataIds) 

733 for dataId in split.chain(): 

734 # Passing in any kind of DataCoordinate alone just returns 

735 # that object. 

736 self.assertIs(dataId, DataCoordinate.standardize(dataId)) 

737 # Same if we also explicitly pass the dimensions we want. 

738 self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions)) 

739 # Same if we pass the dimensions and some irrelevant 

740 # kwargs. 

741 self.assertIs( 

742 dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12) 

743 ) 

744 # Test constructing a new data ID from this one with a 

745 # subset of the dimensions. 

746 # This is not possible for some combinations of 

747 # dimensions if hasFull is False (see 

748 # `DataCoordinate.subset` docs). 

749 newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions) 

750 if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required: 

751 newDataIds = [ 

752 dataId.subset(newDimensions), 

753 DataCoordinate.standardize(dataId, dimensions=newDimensions), 

754 DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12), 

755 ] 

756 for newDataId in newDataIds: 

757 with self.subTest(newDataId=repr(newDataId), type=str(type(dataId))): 

758 commonKeys = dataId.dimensions.required & newDataId.dimensions.required 

759 self.assertTrue(commonKeys) 

760 self.assertEqual( 

761 [newDataId[k] for k in commonKeys], 

762 [dataId[k] for k in commonKeys], 

763 ) 

764 # This should never "downgrade" from 

765 # Complete to Minimal or Expanded to Complete. 

766 if dataId.hasRecords(): 

767 self.assertTrue(newDataId.hasRecords()) 

768 if dataId.hasFull(): 

769 self.assertTrue(newDataId.hasFull()) 

770 # Start from a complete data ID, and pass its values in via several 

771 # different ways that should be equivalent. 

772 for dataId in split.complete: 

773 # Split the keys (dimension names) into two random subsets, so 

774 # we can pass some as kwargs below. 

775 keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2)) 

776 keys2 = dataId.dimensions.names - keys1 

777 newCompleteDataIds = [ 

778 DataCoordinate.standardize(dataId.mapping, universe=dataId.universe), 

779 DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions), 

780 DataCoordinate.standardize( 

781 DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping 

782 ), 

783 DataCoordinate.standardize( 

784 DataCoordinate.make_empty(dataId.dimensions.universe), 

785 dimensions=dataId.dimensions, 

786 **dataId.mapping, 

787 ), 

788 DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe), 

789 DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping), 

790 DataCoordinate.standardize( 

791 {k: dataId[k] for k in keys1}, 

792 universe=dataId.universe, 

793 **{k: dataId[k] for k in keys2}, 

794 ), 

795 DataCoordinate.standardize( 

796 {k: dataId[k] for k in keys1}, 

797 dimensions=dataId.dimensions, 

798 **{k: dataId[k] for k in keys2}, 

799 ), 

800 ] 

801 for newDataId in newCompleteDataIds: 

802 with self.subTest(dataId=repr(dataId), newDataId=repr(newDataId), type=str(type(dataId))): 

803 self.assertEqual(dataId, newDataId) 

804 self.assertTrue(newDataId.hasFull()) 

805 

806 coord = DataCoordinate.standardize({"instrument": "HSC"}, universe=self.allDataIds.universe) 

807 with self.assertRaises(DimensionNameError): 

808 coord.subset(["detector"]) 

809 

810 def testUnion(self): 

811 """Test `DataCoordinate.union`.""" 

812 # Make test groups to combine; mostly random, but with a few explicit 

813 # cases to make sure certain edge cases are covered. 

814 groups = [self.randomDimensionSubset(n=2) for i in range(2)] 

815 groups.append(self.allDataIds.universe["visit"].minimal_group) 

816 groups.append(self.allDataIds.universe["detector"].minimal_group) 

817 groups.append(self.allDataIds.universe["physical_filter"].minimal_group) 

818 groups.append(self.allDataIds.universe["band"].minimal_group) 

819 # Iterate over all combinations, including the same graph with itself. 

820 for group1, group2 in itertools.product(groups, repeat=2): 

821 parentDataIds = self.randomDataIds(n=1) 

822 split1 = self.splitByStateFlags(parentDataIds.subset(group1)) 

823 split2 = self.splitByStateFlags(parentDataIds.subset(group2)) 

824 (parentDataId,) = parentDataIds 

825 for lhs, rhs in itertools.product(split1.chain(), split2.chain()): 

826 unioned = lhs.union(rhs) 

827 with self.subTest(lhs=repr(lhs), rhs=repr(rhs), unioned=repr(unioned)): 

828 self.assertEqual(unioned.dimensions, group1.union(group2)) 

829 self.assertEqual(unioned, parentDataId.subset(unioned.dimensions)) 

830 if unioned.hasFull(): 

831 self.assertEqual(unioned.subset(lhs.dimensions), lhs) 

832 self.assertEqual(unioned.subset(rhs.dimensions), rhs) 

833 if lhs.hasFull() and rhs.hasFull(): 

834 self.assertTrue(unioned.hasFull()) 

835 if lhs.dimensions >= unioned.dimensions and lhs.hasFull(): 

836 self.assertTrue(unioned.hasFull()) 

837 if lhs.hasRecords(): 

838 self.assertTrue(unioned.hasRecords()) 

839 if rhs.dimensions >= unioned.dimensions and rhs.hasFull(): 

840 self.assertTrue(unioned.hasFull()) 

841 if rhs.hasRecords(): 

842 self.assertTrue(unioned.hasRecords()) 

843 if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names: 

844 self.assertTrue(unioned.hasFull()) 

845 if ( 

846 lhs.hasRecords() 

847 and rhs.hasRecords() 

848 and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements 

849 ): 

850 self.assertTrue(unioned.hasRecords()) 

851 

852 def testRegions(self): 

853 """Test that data IDs for a few known dimensions have the expected 

854 regions. 

855 """ 

856 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])): 

857 self.assertIsNotNone(dataId.region) 

858 self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"}) 

859 self.assertEqual(dataId.region, dataId.records["visit"].region) 

860 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit", "detector"])): 

861 self.assertIsNotNone(dataId.region) 

862 self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"}) 

863 self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region) 

864 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])): 

865 self.assertIsNotNone(dataId.region) 

866 self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"}) 

867 self.assertEqual(dataId.region, dataId.records["tract"].region) 

868 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])): 

869 self.assertIsNotNone(dataId.region) 

870 self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"}) 

871 self.assertEqual(dataId.region, dataId.records["patch"].region) 

872 for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])): 

873 self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN) 

874 self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN) 

875 

876 def testTimespans(self): 

877 """Test that data IDs for a few known dimensions have the expected 

878 timespans. 

879 """ 

880 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])): 

881 self.assertIsNotNone(dataId.timespan) 

882 self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"}) 

883 self.assertEqual(dataId.timespan, dataId.records["visit"].timespan) 

884 self.assertEqual(dataId.timespan, dataId.visit.timespan) 

885 # Also test the case for non-temporal DataIds. 

886 for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])): 

887 self.assertIsNone(dataId.timespan) 

888 

889 def testIterableStatusFlags(self): 

890 """Test that DataCoordinateSet and DataCoordinateSequence compute 

891 their hasFull and hasRecords flags correctly from their elements. 

892 """ 

893 dataIds = self.randomDataIds(n=10) 

894 split = self.splitByStateFlags(dataIds) 

895 for cls in (DataCoordinateSet, DataCoordinateSequence): 

896 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull()) 

897 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull()) 

898 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords()) 

899 self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords()) 

900 self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull()) 

901 self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull()) 

902 self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords()) 

903 self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords()) 

904 with self.assertRaises(ValueError): 

905 cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True) 

906 self.assertEqual( 

907 cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(), 

908 not dataIds.dimensions.implied, 

909 ) 

910 self.assertEqual( 

911 cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(), 

912 not dataIds.dimensions.implied, 

913 ) 

914 self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords()) 

915 self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords()) 

916 with self.assertRaises(ValueError): 

917 cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True) 

918 if dataIds.dimensions.implied: 

919 with self.assertRaises(ValueError): 

920 cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True) 

921 

922 def testSetOperations(self): 

923 """Test for self-consistency across DataCoordinateSet's operations.""" 

924 c = self.randomDataIds(n=10).toSet() 

925 a = self.randomDataIds(n=20).toSet() | c 

926 b = self.randomDataIds(n=20).toSet() | c 

927 # Make sure we don't have a particularly unlucky random seed, since 

928 # that would make a lot of this test uninteresting. 

929 self.assertNotEqual(a, b) 

930 self.assertGreater(len(a), 0) 

931 self.assertGreater(len(b), 0) 

932 # The rest of the tests should not depend on the random seed. 

933 self.assertEqual(a, a) 

934 self.assertNotEqual(a, a.toSequence()) 

935 self.assertEqual(a, a.toSequence().toSet()) 

937 self.assertEqual(b, b) 

938 self.assertNotEqual(b, b.toSequence()) 

939 self.assertEqual(b, b.toSequence().toSet()) 

940 self.assertEqual(a & b, a.intersection(b)) 

941 self.assertLessEqual(a & b, a) 

942 self.assertLessEqual(a & b, b) 

943 self.assertEqual(a | b, a.union(b)) 

944 self.assertGreaterEqual(a | b, a) 

945 self.assertGreaterEqual(a | b, b) 

946 self.assertEqual(a - b, a.difference(b)) 

947 self.assertLessEqual(a - b, a) 

948 self.assertLessEqual(b - a, b) 

949 self.assertEqual(a ^ b, a.symmetric_difference(b)) 

950 self.assertGreaterEqual(a ^ b, (a | b) - (a & b)) 

951 

952 def testPackers(self): 

953 (instrument_data_id,) = self.allDataIds.subset( 

954 self.allDataIds.universe.conform(["instrument"]) 

955 ).toSet() 

956 (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"])) 

957 packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions) 

958 packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True) 

959 self.assertEqual(packed_id, detector_data_id["detector"]) 

960 self.assertEqual(max_bits, packer.maxBits) 

961 self.assertEqual( 

962 max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max)) 

963 ) 

964 self.assertEqual(packer.pack(detector_data_id), packed_id) 

965 self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"]) 

966 self.assertEqual(packer.unpack(packed_id), detector_data_id) 

967 

968 def test_dimension_group_pydantic(self): 

969 """Test that DimensionGroup round-trips through Pydantic as long as 

970 it's given the universe when validated. 

971 """ 

972 dimensions = self.allDataIds.dimensions 

973 adapter = pydantic.TypeAdapter(DimensionGroup) 

974 json_str = adapter.dump_json(dimensions) 

975 python_data = adapter.dump_python(dimensions) 

976 self.assertEqual( 

977 dimensions, adapter.validate_json(json_str, context=dict(universe=dimensions.universe)) 

978 ) 

979 self.assertEqual( 

980 dimensions, adapter.validate_python(python_data, context=dict(universe=dimensions.universe)) 

981 ) 

982 self.assertEqual(dimensions, adapter.validate_python(dimensions)) 

983 

984 def test_dimension_element_pydantic(self): 

985 """Test that DimensionElement round-trips through Pydantic as long as 

986 it's given the universe when validated. 

987 """ 

988 element = self.allDataIds.universe["visit"] 

989 adapter = pydantic.TypeAdapter(DimensionElement) 

990 json_str = adapter.dump_json(element) 

991 python_data = adapter.dump_python(element) 

992 self.assertEqual(element, adapter.validate_json(json_str, context=dict(universe=element.universe))) 

993 self.assertEqual( 

994 element, adapter.validate_python(python_data, context=dict(universe=element.universe)) 

995 ) 

996 self.assertEqual(element, adapter.validate_python(element)) 

997 

998 

999if __name__ == "__main__": 

1000 unittest.main()