Coverage for tests/test_dimensions.py: 10%

441 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-09 02:11 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import copy 

23import itertools 

24import math 

25import os 

26import pickle 

27import unittest 

28from dataclasses import dataclass 

29from random import Random 

30from typing import Iterator, Optional 

31 

32import lsst.sphgeom 

33from lsst.daf.butler import ( 

34 Config, 

35 DataCoordinate, 

36 DataCoordinateSequence, 

37 DataCoordinateSet, 

38 Dimension, 

39 DimensionConfig, 

40 DimensionGraph, 

41 DimensionPacker, 

42 DimensionUniverse, 

43 NamedKeyDict, 

44 NamedValueSet, 

45 Registry, 

46 TimespanDatabaseRepresentation, 

47 YamlRepoImportBackend, 

48) 

49from lsst.daf.butler.registry import RegistryConfig 

50 

# Path to the exported dimension-record test dataset (a small subset of the
# HSC RC2 repository) shipped alongside the tests.
DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)

54 

55 

def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file for the
        ``visit``, ``detector``, ``tract``, and ``patch`` dimensions,
        expanded to include dimension records.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        # The backend reads from the open stream, so registration and loading
        # both happen inside the context manager.
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()

75 

76 

class TestDimensionPacker(DimensionPacker):
    """A minimal concrete `DimensionPacker` used to exercise the base class.

    The packed integer is simply the detector ID itself, which keeps the
    round-trip trivially verifiable.
    """

    def __init__(self, fixed: DataCoordinate, dimensions: DimensionGraph):
        super().__init__(fixed, dimensions)
        detector_count = fixed.records["instrument"].detector_max
        self._n_detectors = detector_count
        # Number of bits needed to represent the largest possible detector ID.
        self._max_bits = (detector_count - 1).bit_length()

    @property
    def maxBits(self) -> int:
        # Docstring inherited from DimensionPacker.maxBits
        return self._max_bits

    def _pack(self, dataId: DataCoordinate) -> int:
        # Docstring inherited from DimensionPacker._pack
        return dataId["detector"]

    def unpack(self, packedId: int) -> DataCoordinate:
        # Docstring inherited from DimensionPacker.unpack
        values = {
            "instrument": self.fixed["instrument"],
            "detector": packedId,
        }
        return DataCoordinate.standardize(values, graph=self.dimensions)

106 

107 

class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # Default universe, built from config/dimensions.yaml.
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        """Assert the structural invariants every `DimensionGraph` must obey:
        ordering, required/implied partitioning, and primary-key traversal
        order.
        """
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        # A graph is fully determined by its required dimensions.
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        # Required dimensions are exactly those not implied by any element.
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and any element that is implied by another in the graph
        # must follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        # The traversal must visit every element exactly once.
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        """The universe should expose the configuration it was built from."""
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        """Check `DimensionUniverse.isCompatibleWith` for identical,
        version-skewed, and fully incompatible universes.
        """
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        # A version mismatch alone is still compatible, but must be logged.
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
            self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        """Check the default universe's namespace and minimum version."""
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        """Check that the expected set of static dimensions was read from
        the default configuration, including all htm/healpix levels.
        """
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        """Run the graph invariants on the empty graph and on the graph of
        every static element.
        """
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        """Check the expansion of an instrument-oriented dimension graph."""
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        """Check the expansion of a calibration-oriented dimension graph."""
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        """Check spatial/temporal families for an observation graph."""
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        # Both families are governed by the instrument dimension.
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        """Check the expansion and spatial family of a skymap graph."""
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        """Check that generated table specs for all static elements carry
        the right field types, nullability, and foreign keys.
        """
        tableSpecs = NamedKeyDict({})
        # Only elements with real (non-view) tables get a spec.
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        # Required dependency: mirrored as a non-nullable
                        # primary-key field named after the dependency.
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        # The element itself: field uses its own primary-key
                        # name rather than the dimension name.
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    # Implied dependencies are plain (non-key) fields.
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                # Foreign keys must point at other generated specs and line up
                # column-by-column in type and size.
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object within
        # a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)
379 

380 

@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: Optional[DataCoordinateSequence] = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always `True` and
    `DataCoordinate.hasRecords()` will always return `True` for this attribute.
    """

    expanded: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # A slice keeps the "all" and "just the nth" cases uniform below.
        window = slice(None, None) if n is None else slice(n, n + 1)
        for attribute in (self.minimal, self.complete, self.expanded):
            if attribute is not None:
                yield from attribute[window]

433 

434 

435class DataCoordinateTestCase(unittest.TestCase): 

436 RANDOM_SEED = 10 

437 

    @classmethod
    def setUpClass(cls):
        # Load the shared data IDs once per class: loadDimensionData builds an
        # in-memory registry and imports a YAML export, which is relatively
        # expensive, and the result is only ever read by the tests.
        cls.allDataIds = loadDimensionData()

441 

    def setUp(self):
        # Fresh, fixed-seed RNG per test so random selections are
        # deterministic and independent of test execution order.
        self.rng = Random(self.RANDOM_SEED)

444 

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.  Must not exceed ``len(dataIds)``,
            since `Random.sample` samples without replacement.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` Data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        # The state flags are inherited from the source sequence, so skip the
        # per-element validation (check=False).
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

469 

470 def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph: 

471 """Generate a random `DimensionGraph` that has a subset of the 

472 dimensions in a given one. 

473 

474 Parameters 

475 ---------- 

476 n : `int` 

477 Number of dimensions to select, before automatic expansion by 

478 `DimensionGraph`. 

479 dataIds : `DimensionGraph`, optional 

480 Dimensions to select from. Defaults to ``self.allDataIds.graph``. 

481 

482 Returns 

483 ------- 

484 selected : `DimensionGraph` 

485 ``n`` or more dimensions randomly selected from ``graph`` with 

486 replacement. 

487 """ 

488 if graph is None: 

489 graph = self.allDataIds.graph 

490 return DimensionGraph( 

491 graph.universe, names=self.rng.sample(list(graph.dimensions.names), max(n, len(graph.dimensions))) 

492 ) 

493 

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional.
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both return
            `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDS that only contain values for
            required dimensions, for which ``hasFull()`` may not return `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        # Start with the fully-expanded data IDs; the degraded variants below
        # are derived from them, so ``expanded`` is only dropped at the end.
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Re-standardizing from full.byName() keeps all dimension values
            # but discards the attached records.
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # byName() (without .full) yields only the required-dimension
            # values, so implied values are dropped as well.
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

548 

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    # Indexing by Dimension object and by its name must agree.
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    # The mapping keys are exactly the required dimensions.
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            # complete and expanded data IDs additionally expose ``full``.
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])
568 

569 def test_pickle(self): 

570 for n in range(5): 

571 dimensions = self.randomDimensionSubset() 

572 dataIds = self.randomDataIds(n=1).subset(dimensions) 

573 split = self.splitByStateFlags(dataIds) 

574 for data_id in split.chain(): 

575 s = pickle.dumps(data_id) 

576 read_data_id = pickle.loads(s) 

577 self.assertEqual(data_id, read_data_id) 

578 self.assertEqual(data_id.hasFull(), read_data_id.hasFull()) 

579 self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords()) 

580 if data_id.hasFull(): 

581 self.assertEqual(data_id.full, read_data_id.full) 

582 if data_id.hasRecords(): 

583 self.assertEqual(data_id.records, read_data_id.records) 

584 

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            # Expanded data IDs expose each element's record as an attribute.
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")
            # Non-expanded data IDs raise a targeted error for element names,
            # and the plain AttributeError for unknown names.
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(AttributeError, "only available on expanded DataCoordinates"):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")

605 

606 def testEquality(self): 

607 """Test that different `DataCoordinate` instances with different state 

608 flags can be compared with each other and other mappings. 

609 """ 

610 dataIds = self.randomDataIds(n=2) 

611 split = self.splitByStateFlags(dataIds) 

612 # Iterate over all combinations of different states of DataCoordinate, 

613 # with the same underlying data ID values. 

614 for a0, b0 in itertools.combinations(split.chain(0), 2): 

615 self.assertEqual(a0, b0) 

616 self.assertEqual(a0, b0.byName()) 

617 self.assertEqual(a0.byName(), b0) 

618 # Same thing, for a different data ID value. 

619 for a1, b1 in itertools.combinations(split.chain(1), 2): 

620 self.assertEqual(a1, b1) 

621 self.assertEqual(a1, b1.byName()) 

622 self.assertEqual(a1.byName(), b1) 

623 # Iterate over all combinations of different states of DataCoordinate, 

624 # with different underlying data ID values. 

625 for a0, b1 in itertools.product(split.chain(0), split.chain(1)): 

626 self.assertNotEqual(a0, b1) 

627 self.assertNotEqual(a1, b0) 

628 self.assertNotEqual(a0, b1.byName()) 

629 self.assertNotEqual(a0.byName(), b1) 

630 self.assertNotEqual(a1, b0.byName()) 

631 self.assertNotEqual(a1.byName(), b0) 

632 

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            # Values for the shared keys must be preserved.
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

711 

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    # The union's graph is the union of the operand graphs,
                    # and its values agree with the common parent data ID.
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    # If one operand's graph covers the whole union, that
                    # operand's state flags carry over.
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    # Together the operands' required dimensions may determine
                    # all dimensions (and likewise all elements' records).
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

750 

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        # visit alone: region comes from the visit record.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        # visit+detector: region narrows to the per-detector region.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        # tract: region comes from the tract record in the skymap family.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        # patch: same family, region from the patch record.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        # Mixed families: the combined region lies within each constituent.
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)

784 

785 def testTimespans(self): 

786 """Test that data IDs for a few known dimensions have the expected 

787 timespans. 

788 """ 

789 for dataId in self.randomDataIds(n=4).subset( 

790 DimensionGraph(self.allDataIds.universe, names=["visit"]) 

791 ): 

792 self.assertIsNotNone(dataId.timespan) 

793 self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"}) 

794 self.assertEqual(dataId.timespan, dataId.records["visit"].timespan) 

795 self.assertEqual(dataId.timespan, dataId.visit.timespan) 

796 # Also test the case for non-temporal DataIds. 

797 for dataId in self.randomDataIds(n=4).subset( 

798 DimensionGraph(self.allDataIds.universe, names=["patch"]) 

799 ): 

800 self.assertIsNone(dataId.timespan) 

801 

802 def testIterableStatusFlags(self): 

803 """Test that DataCoordinateSet and DataCoordinateSequence compute 

804 their hasFull and hasRecords flags correctly from their elements. 

805 """ 

806 dataIds = self.randomDataIds(n=10) 

807 split = self.splitByStateFlags(dataIds) 

808 for cls in (DataCoordinateSet, DataCoordinateSequence): 

809 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull()) 

810 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull()) 

811 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords()) 

812 self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords()) 

813 self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull()) 

814 self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull()) 

815 self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords()) 

816 self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords()) 

817 with self.assertRaises(ValueError): 

818 cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True) 

819 self.assertEqual( 

820 cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied 

821 ) 

822 self.assertEqual( 

823 cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied 

824 ) 

825 self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords()) 

826 self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords()) 

827 with self.assertRaises(ValueError): 

828 cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True) 

829 if dataIds.graph.implied: 

830 with self.assertRaises(ValueError): 

831 cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True) 

832 

833 def testSetOperations(self): 

834 """Test for self-consistency across DataCoordinateSet's operations.""" 

835 c = self.randomDataIds(n=10).toSet() 

836 a = self.randomDataIds(n=20).toSet() | c 

837 b = self.randomDataIds(n=20).toSet() | c 

838 # Make sure we don't have a particularly unlucky random seed, since 

839 # that would make a lot of this test uninteresting. 

840 self.assertNotEqual(a, b) 

841 self.assertGreater(len(a), 0) 

842 self.assertGreater(len(b), 0) 

843 # The rest of the tests should not depend on the random seed. 

844 self.assertEqual(a, a) 

845 self.assertNotEqual(a, a.toSequence()) 

846 self.assertEqual(a, a.toSequence().toSet()) 

847 self.assertEqual(a, a.toSequence().toSet()) 

848 self.assertEqual(b, b) 

849 self.assertNotEqual(b, b.toSequence()) 

850 self.assertEqual(b, b.toSequence().toSet()) 

851 self.assertEqual(a & b, a.intersection(b)) 

852 self.assertLessEqual(a & b, a) 

853 self.assertLessEqual(a & b, b) 

854 self.assertEqual(a | b, a.union(b)) 

855 self.assertGreaterEqual(a | b, a) 

856 self.assertGreaterEqual(a | b, b) 

857 self.assertEqual(a - b, a.difference(b)) 

858 self.assertLessEqual(a - b, a) 

859 self.assertLessEqual(b - a, b) 

860 self.assertEqual(a ^ b, a.symmetric_difference(b)) 

861 self.assertGreaterEqual(a ^ b, (a | b) - (a & b)) 

862 

863 def testPackers(self): 

864 (instrument_data_id,) = self.allDataIds.subset( 

865 self.allDataIds.universe.extract(["instrument"]) 

866 ).toSet() 

867 (detector_data_id,) = self.randomDataIds(n=1).subset(self.allDataIds.universe.extract(["detector"])) 

868 packer = TestDimensionPacker(instrument_data_id, detector_data_id.graph) 

869 packed_id, max_bits = packer.pack(detector_data_id, returnMaxBits=True) 

870 self.assertEqual(packed_id, detector_data_id["detector"]) 

871 self.assertEqual(max_bits, packer.maxBits) 

872 self.assertEqual( 

873 max_bits, math.ceil(math.log2(instrument_data_id.records["instrument"].detector_max)) 

874 ) 

875 self.assertEqual(packer.pack(detector_data_id), packed_id) 

876 self.assertEqual(packer.pack(detector=detector_data_id["detector"]), detector_data_id["detector"]) 

877 self.assertEqual(packer.unpack(packed_id), detector_data_id) 

878 

879 

# Allow this test module to be run directly as a script.
if __name__ == "__main__":
    unittest.main()