Coverage for tests/test_dimensions.py: 11% (440 statements), coverage.py v7.3.2, created at 2023-12-01 11:00 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionElement,
    DimensionGroup,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
    ddl,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = registry.dimensions.conform(["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
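

# A minimal usage sketch (not invoked by the test suite): the sequence
# returned by loadDimensionData() is fully expanded, so both state flags are
# set and per-element records are available. Everything used here appears in
# the tests below; the function exists purely as illustration.
def _example_load_usage() -> None:
    data_ids = loadDimensionData()
    assert data_ids.hasFull() and data_ids.hasRecords()
    # Narrow to visit-only data IDs and inspect one of them.
    visit_only = data_ids.subset(data_ids.universe.conform(["visit"]))
    first = visit_only[0]
    assert first.timespan == first.records["visit"].timespan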


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGroup):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            dimensions=self._dimensions,
        )
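

# A minimal round-trip sketch (not invoked by the test suite), mirroring what
# ``testPackers`` at the bottom of this file exercises; ``data_ids`` is
# assumed to be an expanded sequence such as the one loadDimensionData()
# returns.
def _example_packer_round_trip(data_ids: DataCoordinateSequence) -> None:
    universe = data_ids.universe
    (instrument_data_id,) = data_ids.subset(universe.conform(["instrument"])).toSet()
    detector_data_id = data_ids.subset(universe.conform(["detector"]))[0]
    packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
    # pack() and unpack() are inverses for in-range detector IDs.
    packed, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
    assert packer.unpack(packed) == detector_data_id
    assert max_bits == packer.maxBits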


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGroupInvariants(self, group: DimensionGroup):
        elements = list(group.elements)
        for n, element_name in enumerate(elements):
            element = self.universe[element_name]
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.minimal_group, group)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other_name in elements[:n]:
                other = self.universe[other_name]
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other_name in elements[n + 1 :]:
                other = self.universe[other_name]
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.minimal_group.required, element.required)
        self.assertEqual(self.universe.conform(group.required), group)
        self.assertCountEqual(
            group.required,
            [
                dimension_name
                for dimension_name in group.names
                if not any(
                    dimension_name in self.universe[other_name].minimal_group.implied
                    for other_name in group.elements
                )
            ],
        )
        self.assertCountEqual(group.implied, group.names - group.required)
        self.assertCountEqual(group.names, itertools.chain(group.required, group.implied))
        # Check primary key traversal order: each element should follow any
        # element it requires, and any element that is implied by another in
        # the group should follow at least one of those.
        seen: set[str] = set()
        for element_name in group.lookup_order:
            element = self.universe[element_name]
            with self.subTest(required=group.required, implied=group.implied, element=element):
                seen.add(element_name)
                self.assertLessEqual(element.minimal_group.required, seen)
                if element_name in group.implied:
                    self.assertTrue(any(element_name in self.universe[s].implied for s in seen))
        self.assertCountEqual(seen, group.elements)
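
    # An illustrative consequence of the invariants checked above (an
    # assumption for the default universe, where ``band`` is implied by
    # ``visit``): DimensionGroup comparisons act like subset tests, e.g.
    #
    #     universe = DimensionUniverse()
    #     assert universe["band"].minimal_group <= universe.conform(["visit"])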

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
            self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))
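
    # Note (inferred from the assertions above): universes that differ only
    # in version are treated as compatible, with an INFO log message, while
    # universes with a different namespace are not compatible at all.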

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGroupInvariants(self.universe.empty.as_group())
        for element in self.universe.getStaticElements():
            self.checkGroupInvariants(element.minimal_group)

    def testInstrumentDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(group.required, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.governors, {"instrument"})

    def testCalibrationDimensions(self):
        group = self.universe.conform(["physical_filter", "detector"])
        self.assertCountEqual(group.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(group.required, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(group.implied, ("band",))
        self.assertCountEqual(group.elements, group.names)
        self.assertCountEqual(group.governors, {"instrument"})

    def testObservationDimensions(self):
        group = self.universe.conform(["exposure", "detector", "visit"])
        self.assertCountEqual(
            group.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(group.required, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(group.implied, ("physical_filter", "band"))
        self.assertCountEqual(group.elements - group.names, ("visit_detector_region", "visit_definition"))
        self.assertCountEqual(group.spatial.names, ("observation_regions",))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))
        self.assertCountEqual(group.governors, {"instrument"})
        self.assertEqual(group.spatial.names, {"observation_regions"})
        self.assertEqual(group.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(group.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        group = self.universe.conform(["patch"])
        self.assertEqual(group.names, {"skymap", "tract", "patch"})
        self.assertEqual(group.required, {"skymap", "tract", "patch"})
        self.assertEqual(group.implied, set())
        self.assertEqual(group.elements, group.names)
        self.assertEqual(group.governors, {"skymap"})
        self.assertEqual(group.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(group.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        group = self.universe.conform(["visit", "detector", "tract", "patch", "htm7", "exposure"])
        self.assertCountEqual(group.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(group.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs: NamedKeyDict[DimensionElement, ddl.TableSpec] = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.dimensions)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            group1 = element1.minimal_group
            group2 = pickle.loads(pickle.dumps(group1))
            self.assertIs(group1, group2)
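
    # Note (inferred from the identity assertions above): DimensionUniverse
    # and its elements behave as per-process singletons, so pickling, copying,
    # and deep-copying all return the already-constructed instance rather
    # than building a new one.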


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.dimensions.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator[DataCoordinate]:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
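
# A minimal usage sketch of SplitByStateFlags (illustrative only; the
# splitByStateFlags() helper in DataCoordinateTestCase below is what actually
# builds these):
#
#     split = SplitByStateFlags(expanded=data_ids)
#     for data_id in split.chain():   # every data ID, in every populated state
#         ...
#     for data_id in split.chain(0):  # only the first data ID per state
#         ...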


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            dimensions=dataIds.dimensions,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )
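
    # Note: `Random.sample` draws without replacement, so ``n`` must not
    # exceed ``len(dataIds)``; the fixed RANDOM_SEED above keeps these
    # selections reproducible from run to run.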

    def randomDimensionSubset(self, n: int = 3, group: DimensionGroup | None = None) -> DimensionGroup:
        """Generate a random `DimensionGroup` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGroup`.
        group : `DimensionGroup`, optional
            Dimensions to select from. Defaults to
            ``self.allDataIds.dimensions``.

        Returns
        -------
        selected : `DimensionGroup`
            ``n`` or more dimensions randomly selected from ``group`` without
            replacement.
        """
        if group is None:
            group = self.allDataIds.dimensions
        return group.universe.conform(self.rng.sample(list(group.names), min(n, len(group))))

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.mapping, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [
                    DataCoordinate.standardize(e.required, dimensions=dataIds.dimensions)
                    for e in result.expanded
                ],
                dimensions=dataIds.dimensions,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.dimensions.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingViews(self):
        """Test that the ``mapping`` and ``required`` attributes in
        `DataCoordinate` are self-consistent and consistent with the
        ``dimensions`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(dataId.required.keys(), dataId.dimensions.required)
                    self.assertEqual(
                        list(dataId.required.values()), [dataId[d] for d in dataId.dimensions.required]
                    )
                    self.assertEqual(
                        list(dataId.required_values), [dataId[d] for d in dataId.dimensions.required]
                    )
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.dimensions.names, dataId.mapping.keys())
                    self.assertEqual(
                        list(dataId.mapping.values()), [dataId[k] for k in dataId.mapping.keys()]
                    )

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id: DataCoordinate = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.mapping, read_data_id.mapping)
                if data_id.hasRecords():
                    for element_name in data_id.dimensions.elements:
                        self.assertEqual(
                            data_id.records[element_name], read_data_id.records[element_name]
                        )

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element_name in data_id.dimensions.elements:
                    self.assertIs(getattr(data_id, element_name), data_id.records[element_name])
                    self.assertIn(element_name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element_name in data_id.dimensions.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element_name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(
                    dataId, DataCoordinate.standardize(dataId, dimensions=dataId.dimensions, htm7=12)
                )
                # Test constructing a new data ID from this one with a subset
                # of the dimensions. This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, group=dataId.dimensions)
                if dataId.hasFull() or dataId.dimensions.required >= newDimensions.required:
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions),
                        DataCoordinate.standardize(dataId, dimensions=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.dimensions.required & newDataId.dimensions.required
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or from Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(self.rng.sample(list(dataId.dimensions.names), len(dataId.dimensions) // 2))
                keys2 = dataId.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dataId.mapping, dimensions=dataId.dimensions),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe), **dataId.mapping
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.make_empty(dataId.dimensions.universe),
                        dimensions=dataId.dimensions,
                        **dataId.mapping,
                    ),
                    DataCoordinate.standardize(**dataId.mapping, universe=dataId.universe),
                    DataCoordinate.standardize(dimensions=dataId.dimensions, **dataId.mapping),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        dimensions=dataId.dimensions,
                        **{k: dataId[k] for k in keys2},
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test groups to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        groups = [self.randomDimensionSubset(n=2) for i in range(2)]
        groups.append(self.allDataIds.universe["visit"].minimal_group)
        groups.append(self.allDataIds.universe["detector"].minimal_group)
        groups.append(self.allDataIds.universe["physical_filter"].minimal_group)
        groups.append(self.allDataIds.universe["band"].minimal_group)
        # Iterate over all combinations, including the same group with itself.
        for group1, group2 in itertools.product(groups, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(group1))
            split2 = self.splitByStateFlags(parentDataIds.subset(group2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.dimensions, group1.union(group2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.dimensions))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.dimensions), lhs)
                        self.assertEqual(unioned.subset(rhs.dimensions), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.dimensions >= unioned.dimensions and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.dimensions >= unioned.dimensions and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.dimensions.required | rhs.dimensions.required >= unioned.dimensions.names:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.dimensions.elements | rhs.dimensions.elements >= unioned.dimensions.elements
                    ):
                        self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            self.allDataIds.universe.conform(["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["tract"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.dimensions.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["visit", "tract"])):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["visit"])):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.dimensions.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(self.allDataIds.universe.conform(["patch"])):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, dimensions=dataIds.dimensions, check=False).hasRecords())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasFull())
            self.assertTrue(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasFull())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.complete, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertEqual(
                cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasFull(),
                not dataIds.dimensions.implied,
            )
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, dimensions=dataIds.dimensions, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, dimensions=dataIds.dimensions, hasRecords=True, check=True)
            if dataIds.dimensions.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, dimensions=dataIds.dimensions, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.conform(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.conform(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.dimensions)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)


if __name__ == "__main__":
    unittest.main()