Coverage for tests/test_dimensions.py: 11% (444 statements)

coverage.py v7.3.2, created at 2023-10-25 15:14 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
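
# Illustrative usage (a sketch, not executed at import time): because the
# query result is ``expanded()``, the returned sequence has both state flags
# set, e.g.
#
#     data_ids = loadDimensionData()
#     assert data_ids.hasFull() and data_ids.hasRecords()
#     # Each element is a DataCoordinate over visit/detector/tract/patch.
#
# DataCoordinateTestCase below caches this in ``setUpClass`` so the SQLite
# import only happens once per test run.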


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            graph=self.dimensions,
        )
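
# Illustrative round trip (a sketch; ``instrument_data_id`` and
# ``detector_data_id`` are hypothetical stand-ins for the values built in
# ``testPackers`` below):
#
#     packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph)
#     packed = packer.pack(detector_data_id)   # == detector_data_id["detector"]
#     assert packer.unpack(packed) == detector_data_id
#
# Note that ``(n - 1).bit_length()`` equals ``math.ceil(math.log2(n))`` for any
# positive integer ``n``, which is why ``testPackers`` can check ``maxBits``
# against the logarithm of ``detector_max``.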


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check the primary key traversal order: each element should follow
        # any it requires, and any element implied by another in the graph
        # should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)
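
    # A quick sketch of what these invariants mean in practice (values assume
    # the default ``config/dimensions.yaml``): constructing a graph from a few
    # names automatically expands to their required and implied dependencies,
    # e.g.
    #
    #     graph = DimensionGraph(self.universe, names=("visit",))
    #     # graph.required -> {instrument, visit}
    #     # graph.implied  -> {physical_filter, band}
    #
    # so checkGraphInvariants never sees a graph with a missing dependency.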

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # The default universe was already at version 2 when this test was
        # added.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal families are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this attribute
    if and only if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is the default), iterate over only the
            ``n``-th data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
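
# Illustrative usage (a sketch; ``data_ids`` is a hypothetical expanded
# sequence such as the one returned by ``loadDimensionData``):
#
#     split = SplitByStateFlags(expanded=data_ids)
#     all_states = list(split.chain())    # every data ID, all populated states
#     firsts = list(split.chain(0))       # only the first data ID per state
#
# Attributes left as `None` are simply skipped by ``chain``.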


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: DimensionGraph | None = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from. Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        # Clamp the sample size so we never request more names than exist and
        # actually get a proper subset when the graph is large.
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
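
    # A sketch of the resulting states (hypothetical, for a graph whose
    # ``implied`` set is non-empty):
    #
    #     split = self.splitByStateFlags(dataIds)
    #     split.minimal[0].hasFull()      # False: implied values were dropped
    #     split.complete[0].hasFull()     # True, but hasRecords() is False
    #     split.expanded[0].hasRecords()  # True: dimension records attached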

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(
                        list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]
                    )

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  Note that ``a1`` and
        # ``b0`` still hold the last values bound in the loops above, so the
        # assertions using them also compare across different data ID values.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(a1, b0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(a1, b0.byName())
            self.assertNotEqual(a1.byName(), b0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a subset
                # of the dimensions.  This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements
                    ):
                        self.assertTrue(unioned.hasRecords())
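
    # A concrete sketch of the flag logic above (hypothetical values, assuming
    # the default universe):
    #
    #     lhs = DataCoordinate.standardize(            # detector graph has no
    #         instrument="HSC", detector=50,           # implied dimensions, so
    #         universe=universe,                       # hasFull() is True
    #     )
    #     rhs = DataCoordinate.standardize(            # physical_filter graph;
    #         instrument="HSC", physical_filter="HSC-R", band="r",
    #         universe=universe,                       # full, since band given
    #     )
    #     both = lhs.union(rhs)
    #
    # Here ``both.graph`` covers instrument, detector, physical_filter, and
    # band, and ``both.hasFull()`` should hold because both operands had
    # ``hasFull()`` True.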

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.extract(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)


if __name__ == "__main__":
    unittest.main()