Coverage for tests/test_dimensions.py: 11% of 444 statements (coverage.py v7.3.2, created at 2023-10-27 09:44 +0000)

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively.  If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import math
import os
import pickle
import unittest
from collections.abc import Iterator
from dataclasses import dataclass
from random import Random

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionPacker,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = _RegistryFactory(config).create_from_config()
    with open(DIMENSION_DATA_FILE) as stream:
        backend = YamlRepoImportBackend(stream, registry)
    backend.register()
    backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
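
# A quick orientation for the helper above (a sketch, not executed as part of
# the test suite; it assumes the export file ships with the repository, as
# configured above):
#
#     dataIds = loadDimensionData()
#     dataIds.hasFull(), dataIds.hasRecords()  # both True, thanks to .expanded()
#     dataIds.graph.dimensions.names           # visit/detector/tract/patch plus
#                                              # their required/implied expansion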


class ConcreteTestDimensionPacker(DimensionPacker):
    """A concrete `DimensionPacker` for testing its base class implementations.

    This class just returns the detector ID as-is.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        self._n_detectors = fixed.records["instrument"].detector_max
        self._max_bits = (self._n_detectors - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        return DataCoordinate.standardize(
            {
                "instrument": self.fixed["instrument"],
                "detector": packedId,
            },
            graph=self.dimensions,
        )
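
# An illustrative round trip for the packer above (a sketch, assuming an
# expanded instrument data ID whose record has detector_max == 200; the real,
# executed version is in DataCoordinateTestCase.testPackers below):
#
#     packer = ConcreteTestDimensionPacker(instrument_data_id, detector_graph)
#     packer.maxBits        # (200 - 1).bit_length() == 8
#     packer.pack(data_id)  # just the detector value, e.g. 42
#     packer.unpack(42)     # {instrument, detector} data ID with detector == 42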


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and an element that is implied by any other in the graph
        # should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
            self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # This test was added when the universe version was 2; later versions
        # must also pass.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])
        self.assertEqual(self.universe["visit_definition"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_system_membership"].populated_by, self.universe["visit"])
        self.assertEqual(self.universe["visit_detector_region"].populated_by, self.universe["visit"])
        self.assertEqual(
            self.universe.get_elements_populated_by(self.universe["visit"]),
            NamedValueSet(
                {
                    self.universe["visit"],
                    self.universe["visit_definition"],
                    self.universe["visit_system_membership"],
                    self.universe["visit_detector_region"],
                }
            ),
        )

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal families are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target, strict=True):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: DataCoordinateSequence | None = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this attribute
    if and only if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: DataCoordinateSequence | None = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: int | None = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``n``-th
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
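
# A small sketch of how SplitByStateFlags is consumed by the tests below
# (``split`` stands for the result of ``splitByStateFlags`` in
# DataCoordinateTestCase; illustration only):
#
#     for data_id in split.chain():   # every data ID, in every state
#         ...
#     for data_id in split.chain(0):  # only the 0th data ID of each state
#         ...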


class DataCoordinateTestCase(unittest.TestCase):
    """Test `DataCoordinate`."""

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: DataCoordinateSequence | None = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: DimensionGraph | None = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from.  Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )

    def splitByStateFlags(
        self,
        dataIds: DataCoordinateSequence | None = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result
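
    # For orientation: given a graph with implied dimensions (e.g. visit
    # implies physical_filter and band), the three sequences built above hold
    # the same values with strictly decreasing information (a sketch only):
    #
    #     split.expanded[0].hasRecords()  # True: dimension records attached
    #     split.complete[0].hasFull()     # True, but hasRecords() is False
    #     split.minimal[0].hasFull()      # False here: implied values dropped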

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])

    def test_pickle(self):
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    data_id.not_a_dimension_name

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  Compare both directions
        # using the loop's own variables (the original compared variables
        # leaked from the loops above).
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(b1, a0.byName())
            self.assertNotEqual(b1.byName(), a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a subset
                # of the dimensions.  This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or from Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if (
                        lhs.hasRecords()
                        and rhs.hasRecords()
                        and lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements
                    ):
                        self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

    def testPackers(self):
        (instrument_data_id,) = self.allDataIds.subset(
            self.allDataIds.universe.extract(["instrument"])
        ).toSet()
        (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"]))
        packer = ConcreteTestDimensionPacker(instrument_data_id, detector_data_id.graph)
        packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True)
        self.assertEqual(packed_id, detector_data_id["detector"])
        self.assertEqual(max_bits, packer.maxBits)
        self.assertEqual(
            max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max))
        )
        self.assertEqual(packer.pack(detector_data_id), packed_id)
        self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"])
        self.assertEqual(packer.unpack(packed_id), detector_data_id)
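
    # A note on the arithmetic above: for any integer n >= 1,
    # (n - 1).bit_length() == math.ceil(math.log2(n)), so the packer's maxBits
    # ((detector_max - 1).bit_length()) matches the ceil/log2 expression in
    # the assertion.  For example, if detector_max were 200:
    # (199).bit_length() == 8 == math.ceil(math.log2(200)).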


if __name__ == "__main__":
    unittest.main()