# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
    backend.register()
    backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()


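# A small, hedged sanity-check sketch (not collected by unittest): the
# sequence returned above is fully "expanded" because of the ``.expanded()``
# call in `loadDimensionData`, so both state flags should be true.
def _example_loaded_flags() -> None:
    data_ids = loadDimensionData()
    # Expanded data IDs carry values for all dimensions plus their records.
    assert data_ids.hasFull() and data_ids.hasRecords()
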

class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            graph=self.dimensions,
        )


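# A hedged usage sketch for the packer above (argument names here are
# illustrative only; ``testPackers`` below exercises the real thing).
# ``fixed`` must be an expanded data ID identifying the instrument, so the
# packer can read ``detector_max`` from its record.
def _example_pack_roundtrip(fixed: DataCoordinate, data_id: DataCoordinate) -> None:
    packer = ConcreteTestDimensionPacker(fixed, data_id.graph)
    # The base class wraps `_pack`, optionally returning the bit width.
    packed, max_bits = packer.pack(data_id, returnMaxBits=True)
    assert packed == data_id["detector"]
    assert max_bits == packer.maxBits
    # Packing is reversible for this trivial packer.
    assert packer.unpack(packed) == data_id
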

class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and any element that is implied by another in the graph
        # should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
            self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object within
        # a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)

379 

380@dataclass 

381class SplitByStateFlags: 

382 """A struct that separates data IDs with different states but the same 

383 values. 

384 """ 

385 

386 minimal: DataCoordinateSequence | None = None 

387 """Data IDs that only contain values for required dimensions. 

388 

389 `DataCoordinateSequence.hasFull()` will return `True` for this if and only 

390 if ``minimal.graph.implied`` has no elements. 

391 `DataCoordinate.hasRecords()` will always return `False`. 

392 """ 

393 

394 complete: DataCoordinateSequence | None = None 

395 """Data IDs that contain values for all dimensions. 

396 

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
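

# A hedged usage sketch for the struct above (not itself a test): ``chain(0)``
# yields the zeroth data ID once per populated attribute, so the same values
# can be compared across the minimal/complete/expanded states, much as
# ``DataCoordinateTestCase.testEquality`` does below.
def _example_split_usage(split: SplitByStateFlags) -> None:
    for a, b in itertools.combinations(split.chain(0), 2):
        # Equal values compare equal regardless of how much state is attached.
        assert a == b
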

class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: DimensionGraph | None = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from. Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        return DimensionGraph(
            graph.universe, names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions)))
        )

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return
            `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])

    def test_pickle(self):
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  Note that ``a1`` and
        # ``b0`` deliberately reuse the loop variables left over from the two
        # loops above, so values 0 and 1 are compared in both directions.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(a1, b0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(a1, b0.byName())
            self.assertNotEqual(a1.byName(), b0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.extract(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)
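

# A hedged, minimal sketch of the `DataCoordinate.standardize` pattern the
# tests above exercise.  The dimension names assume the default universe;
# "HSC" is just an illustrative instrument value.
def _example_standardize() -> None:
    universe = DimensionUniverse()
    data_id = DataCoordinate.standardize(instrument="HSC", detector=50, universe=universe)
    # The graph is inferred from the keys and expanded to include required
    # dependencies, so it equals the graph generated by "detector" alone.
    assert data_id.graph == DimensionGraph(universe, names=["detector"])
    assert data_id["detector"] == 50
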

if __name__ == "__main__":
    unittest.main()