Coverage for tests/test_dimensions.py: 10%

441 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-28 10:10 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import copy 

23import itertools 

24import math 

25import os 

26import pickle 

27import unittest 

28from collections.abc import Iterator 

29from dataclasses import dataclass 

30from random import Random 

31 

32import lsst.sphgeom 

33from lsst.daf.butler import ( 

34 Config, 

35 DataCoordinate, 

36 DataCoordinateSequence, 

37 DataCoordinateSet, 

38 Dimension, 

39 DimensionConfig, 

40 DimensionGraph, 

41 DimensionPacker, 

42 DimensionUniverse, 

43 NamedKeyDict, 

44 NamedValueSet, 

45 Registry, 

46 TimespanDatabaseRepresentation, 

47 YamlRepoImportBackend, 

48) 

49from lsst.daf.butler.registry import RegistryConfig 

50 

# Path to the dimension-data export used by these tests: a small subset of the
# HSC RC2 registry content checked into the test data directory.
DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)

54 

55 

def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file, expanded to
        include dimension records.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
    backend.register()
    backend.load(datastore=None)
    # Expand to these four dimensions (plus whatever they require/imply) and
    # return fully-expanded data IDs so record access works downstream.
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()

75 

76 

class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    Packing is intentionally trivial: the detector ID is used directly as the
    packed integer.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        # Number of detectors bounds the packed values, which in turn fixes
        # how many bits a packed ID can occupy.
        detector_max = fixed.records["instrument"].detector_max
        self._n_detectors = detector_max
        self._max_bits = (detector_max - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        values = {
            "instrument": self.fixed["instrument"],
            "detector": packedId,
        }
        return DataCoordinate.standardize(values, graph=self.dimensions)

106 

107 

class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # Default universe built from config/dimensions.yaml; construction is
        # cached internally, so this is cheap to do per-test.
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        # Shared helper: assert the structural invariants every
        # DimensionGraph must satisfy (set-like ordering, required/implied
        # partitioning, and primary-key traversal order).
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        # A graph is fully determined by its required dimensions.
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        # Required == dimensions not implied by any element of the graph.
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and element that is implied by any other in the graph
        # follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        # Traversal must visit every element exactly once.
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        # The universe must retain the config it was built from.
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        # A version difference alone should still be compatible, but must be
        # reported via a log message.
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        # The static dimensions must be exactly those declared in
        # config/dimensions.yaml, plus the htm/healpix skypix families.
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        # The empty graph and every single-element graph must satisfy the
        # shared invariants.
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        # Expanding (exposure, detector, visit) should pull in the
        # instrument-governed dimensions and join elements.
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        # (physical_filter, detector) expands to include instrument (required)
        # and band (implied), with no extra join elements.
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        # Same dimension content as testInstrumentDimensions, but here we also
        # check the spatial/temporal families and their governors.
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        # patch requires both tract and skymap; everything is required (no
        # implied dimensions) and spatial topology is governed by skymap.
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        # Three independent spatial families, one temporal family.
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        # Build table specs for every static element that has its own table
        # (views are excluded), then check column and foreign-key consistency.
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        # Required dependencies appear as non-nullable
                        # primary-key columns matching the dependency's own
                        # primary key definition.
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        # The element's own key column uses its primaryKey
                        # name rather than the dimension name.
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    # Implied dimensions become plain (non-key) columns.
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                # Every foreign key must point at another generated table, and
                # its source/target columns must agree field-by-field.
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object within
        # a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        # Elements and graphs should round-trip to the identical singletons,
        # too, since they belong to the singleton universe.
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)

379 

380 

@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always `True` and
    `DataCoordinate.hasRecords()` will always return `True` for this attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # A single-element slice (rather than direct indexing) keeps the
        # "only the nth entry" case uniform with the "all entries" case.
        selector = slice(None) if n is None else slice(n, n + 1)
        for group in (self.minimal, self.complete, self.expanded):
            if group is not None:
                yield from group[selector]

433 

434 

435class DataCoordinateTestCase(unittest.TestCase): 

436 """Test `DataCoordinate`.""" 

437 

438 RANDOM_SEED = 10 

439 

    @classmethod
    def setUpClass(cls):
        # Loading the export file is expensive, so do it once for the whole
        # test case and share the result across tests.
        cls.allDataIds = loadDimensionData()

443 

    def setUp(self):
        # Fixed seed makes the "random" selections reproducible across runs.
        self.rng = Random(self.RANDOM_SEED)

446 

447 def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None): 

448 """Select random data IDs from those loaded from test data. 

449 

450 Parameters 

451 ---------- 

452 n : `int` 

453 Number of data IDs to select. 

454 dataIds : `DataCoordinateSequence`, optional 

455 Data IDs to select from. Defaults to ``self.allDataIds``. 

456 

457 Returns 

458 ------- 

459 selected : `DataCoordinateSequence` 

460 ``n`` Data IDs randomly selected from ``dataIds`` with replacement. 

461 """ 

462 if dataIds is None: 

463 dataIds = self.allDataIds 

464 return DataCoordinateSequence( 

465 self.rng.sample(dataIds, n), 

466 graph=dataIds.graph, 

467 hasFull=dataIds.hasFull(), 

468 hasRecords=dataIds.hasRecords(), 

469 check=False, 

470 ) 

471 

472 def randomDimensionSubset(self, n: int = 3, graph: DimensionGraph | None = None) -> DimensionGraph: 

473 """Generate a random `DimensionGraph` that has a subset of the 

474 dimensions in a given one. 

475 

476 Parameters 

477 ---------- 

478 n : `int` 

479 Number of dimensions to select, before automatic expansion by 

480 `DimensionGraph`. 

481 dataIds : `DimensionGraph`, optional 

482 Dimensions to select from. Defaults to ``self.allDataIds.graph``. 

483 

484 Returns 

485 ------- 

486 selected : `DimensionGraph` 

487 ``n`` or more dimensions randomly selected from ``graph`` with 

488 replacement. 

489 """ 

490 if graph is None: 

491 graph = self.allDataIds.graph 

492 return DimensionGraph( 

493 graph.universe, names=self.rng.sample(list(graph.dimensions.names), max(n, len(graph.dimensions))) 

494 ) 

495 

496 def splitByStateFlags( 

497 self, 

498 dataIds: DataCoordinateSequence | None = None, 

499 *, 

500 expanded: bool = True, 

501 complete: bool = True, 

502 minimal: bool = True, 

503 ) -> SplitByStateFlags: 

504 """Given a sequence of data IDs, generate new equivalent sequences 

505 containing less information. 

506 

507 Parameters 

508 ---------- 

509 dataIds : `DataCoordinateSequence`, optional. 

510 Data IDs to start from. Defaults to ``self.allDataIds``. 

511 ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return 

512 `True`. 

513 expanded : `bool`, optional 

514 If `True` (default) include the original data IDs that contain all 

515 information in the result. 

516 complete : `bool`, optional 

517 If `True` (default) include data IDs for which ``hasFull()`` 

518 returns `True` but ``hasRecords()`` does not. 

519 minimal : `bool`, optional 

520 If `True` (default) include data IDS that only contain values for 

521 required dimensions, for which ``hasFull()`` may not return `True`. 

522 

523 Returns 

524 ------- 

525 split : `SplitByStateFlags` 

526 A dataclass holding the indicated data IDs in attributes that 

527 correspond to the boolean keyword arguments. 

528 """ 

529 if dataIds is None: 

530 dataIds = self.allDataIds 

531 assert dataIds.hasFull() and dataIds.hasRecords() 

532 result = SplitByStateFlags(expanded=dataIds) 

533 if complete: 

534 result.complete = DataCoordinateSequence( 

535 [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded], 

536 graph=dataIds.graph, 

537 ) 

538 self.assertTrue(result.complete.hasFull()) 

539 self.assertFalse(result.complete.hasRecords()) 

540 if minimal: 

541 result.minimal = DataCoordinateSequence( 

542 [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded], 

543 graph=dataIds.graph, 

544 ) 

545 self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied) 

546 self.assertFalse(result.minimal.hasRecords()) 

547 if not expanded: 

548 result.expanded = None 

549 return result 

550 

551 def testMappingInterface(self): 

552 """Test that the mapping interface in `DataCoordinate` and (when 

553 applicable) its ``full`` property are self-consistent and consistent 

554 with the ``graph`` property. 

555 """ 

556 for n in range(5): 

557 dimensions = self.randomDimensionSubset() 

558 dataIds = self.randomDataIds(n=1).subset(dimensions) 

559 split = self.splitByStateFlags(dataIds) 

560 for dataId in split.chain(): 

561 with self.subTest(dataId=dataId): 

562 self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()]) 

563 self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()]) 

564 self.assertEqual(dataId.keys(), dataId.graph.required) 

565 for dataId in itertools.chain(split.complete, split.expanded): 

566 with self.subTest(dataId=dataId): 

567 self.assertTrue(dataId.hasFull()) 

568 self.assertEqual(dataId.graph.dimensions, dataId.full.keys()) 

569 self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]) 

570 

571 def test_pickle(self): 

572 for n in range(5): 

573 dimensions = self.randomDimensionSubset() 

574 dataIds = self.randomDataIds(n=1).subset(dimensions) 

575 split = self.splitByStateFlags(dataIds) 

576 for data_id in split.chain(): 

577 s = pickle.dumps(data_id) 

578 read_data_id = pickle.loads(s) 

579 self.assertEqual(data_id, read_data_id) 

580 self.assertEqual(data_id.hasFull(), read_data_id.hasFull()) 

581 self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords()) 

582 if data_id.hasFull(): 

583 self.assertEqual(data_id.full, read_data_id.full) 

584 if data_id.hasRecords(): 

585 self.assertEqual(data_id.records, read_data_id.records) 

586 

587 def test_record_attributes(self): 

588 """Test that dimension records are available as attributes on expanded 

589 data coordinates. 

590 """ 

591 for n in range(5): 

592 dimensions = self.randomDimensionSubset() 

593 dataIds = self.randomDataIds(n=1).subset(dimensions) 

594 split = self.splitByStateFlags(dataIds) 

595 for data_id in split.expanded: 

596 for element in data_id.graph.elements: 

597 self.assertIs(getattr(data_id, element.name), data_id.records[element.name]) 

598 self.assertIn(element.name, dir(data_id)) 

599 with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"): 

600 getattr(data_id, "not_a_dimension_name") 

601 for data_id in itertools.chain(split.minimal, split.complete): 

602 for element in data_id.graph.elements: 

603 with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"): 

604 getattr(data_id, element.name) 

605 with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"): 

606 getattr(data_id, "not_a_dimension_name") 

607 

608 def testEquality(self): 

609 """Test that different `DataCoordinate` instances with different state 

610 flags can be compared with each other and other mappings. 

611 """ 

612 dataIds = self.randomDataIds(n=2) 

613 split = self.splitByStateFlags(dataIds) 

614 # Iterate over all combinations of different states of DataCoordinate, 

615 # with the same underlying data ID values. 

616 for a0, b0 in itertools.combinations(split.chain(0), 2): 

617 self.assertEqual(a0, b0) 

618 self.assertEqual(a0, b0.byName()) 

619 self.assertEqual(a0.byName(), b0) 

620 # Same thing, for a different data ID value. 

621 for a1, b1 in itertools.combinations(split.chain(1), 2): 

622 self.assertEqual(a1, b1) 

623 self.assertEqual(a1, b1.byName()) 

624 self.assertEqual(a1.byName(), b1) 

625 # Iterate over all combinations of different states of DataCoordinate, 

626 # with different underlying data ID values. 

627 for a0, b1 in itertools.product(split.chain(0), split.chain(1)): 

628 self.assertNotEqual(a0, b1) 

629 self.assertNotEqual(a1, b0) 

630 self.assertNotEqual(a0, b1.byName()) 

631 self.assertNotEqual(a0.byName(), b1) 

632 self.assertNotEqual(a1, b0.byName()) 

633 self.assertNotEqual(a1.byName(), b0) 

634 

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            # All values common to the old and new data IDs
                            # must be preserved by the conversion.
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                # Each entry below passes the same complete set of values, but
                # via different combinations of positional mapping, prototype
                # data ID, universe/graph argument, and keyword arguments.
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

713 

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    # The union's graph is the union of the operand graphs,
                    # and its values match the parent both operands came from.
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    # The remaining checks are sufficient (not necessary)
                    # conditions under which the union must preserve the
                    # hasFull/hasRecords state of its operands.
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        # One operand alone covers the whole union.
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        # Together the required values span every dimension.
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

752 

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        # visit alone: region comes from the visit record.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        # visit+detector: region comes from the per-detector join element.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        # tract: region comes from the tract record.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        # patch: region comes from the patch record.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        # With two spatial families, the data ID's region is contained in
        # (WITHIN) each contributing record's region.
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

786 

787 def testTimespans(self): 

788 """Test that data IDs for a few known dimensions have the expected 

789 timespans. 

790 """ 

791 for dataId in self.randomDataIds(n=4).subset( 

792 DimensionGraph(self.allDataIds.universe, names=["visit"]) 

793 ): 

794 self.assertIsNotNone(dataId.timespan) 

795 self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"}) 

796 self.assertEqual(dataId.timespan, dataId.records["visit"].timespan) 

797 self.assertEqual(dataId.timespan, dataId.visit.timespan) 

798 # Also test the case for non-temporal DataIds. 

799 for dataId in self.randomDataIds(n=4).subset( 

800 DimensionGraph(self.allDataIds.universe, names=["patch"]) 

801 ): 

802 self.assertIsNone(dataId.timespan) 

803 

804 def testIterableStatusFlags(self): 

805 """Test that DataCoordinateSet and DataCoordinateSequence compute 

806 their hasFull and hasRecords flags correctly from their elements. 

807 """ 

808 dataIds = self.randomDataIds(n=10) 

809 split = self.splitByStateFlags(dataIds) 

810 for cls in (DataCoordinateSet, DataCoordinateSequence): 

811 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull()) 

812 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull()) 

813 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords()) 

814 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords()) 

815 self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull()) 

816 self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull()) 

817 self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords()) 

818 self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords()) 

819 with self.assertRaises(ValueError): 

820 cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True) 

821 self.assertEqual( 

822 cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied 

823 ) 

824 self.assertEqual( 

825 cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied 

826 ) 

827 self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords()) 

828 self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords()) 

829 with self.assertRaises(ValueError): 

830 cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True) 

831 if dataIds.graph.implied: 

832 with self.assertRaises(ValueError): 

833 cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True) 

834 

835 def testSetOperations(self): 

836 """Test for self-consistency across DataCoordinateSet's operations.""" 

837 c = self.randomDataIds(n=10).toSet() 

838 a = self.randomDataIds(n=20).toSet() | c 

839 b = self.randomDataIds(n=20).toSet() | c 

840 # Make sure we don't have a particularly unlucky random seed, since 

841 # that would make a lot of this test uninteresting. 

842 self.assertNotEqual(a, b) 

843 self.assertGreater(len(a), 0) 

844 self.assertGreater(len(b), 0) 

845 # The rest of the tests should not depend on the random seed. 

846 self.assertEqual(a, a) 

847 self.assertNotEqual(a, a.toSequence()) 

848 self.assertEqual(a, a.toSequence().toSet()) 

849 self.assertEqual(a, a.toSequence().toSet()) 

850 self.assertEqual(b, b) 

851 self.assertNotEqual(b, b.toSequence()) 

852 self.assertEqual(b, b.toSequence().toSet()) 

853 self.assertEqual(a & b, a.intersection(b)) 

854 self.assertLessEqual(a & b, a) 

855 self.assertLessEqual(a & b, b) 

856 self.assertEqual(a | b, a.union(b)) 

857 self.assertGreaterEqual(a | b, a) 

858 self.assertGreaterEqual(a | b, b) 

859 self.assertEqual(a - b, a.difference(b)) 

860 self.assertLessEqual(a - b, a) 

861 self.assertLessEqual(b - a, b) 

862 self.assertEqual(a ^ b, a.symmetric_difference(b)) 

863 self.assertGreaterEqual(a ^ b, (a | b) - (a & b)) 

864 

865 def testPackers(self): 

866 (instrument_data_id,) = self.allDataIds.subset( 

867 self.allDataIds.universe.extract(["instrument"]) 

868 ).toSet() 

869 (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"])) 

870 packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph) 

871 packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True) 

872 self.assertEqual(packed_id, detector_data_id["detector"]) 

873 self.assertEqual(max_bits, packer.maxBits) 

874 self.assertEqual( 

875 max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max)) 

876 ) 

877 self.assertEqual(packer.pack(detector_data_id), packed_id) 

878 self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"]) 

879 self.assertEqual(packer.unpack(packed_id), detector_data_id) 

880 

881 

# Allow the test suite to be run directly as a script.
if __name__ == "__main__":
    unittest.main()