Coverage for tests/test_dimensions.py: 10%

382 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-24 23:50 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import copy 

23import itertools 

24import os 

25import pickle 

26import unittest 

27from dataclasses import dataclass 

28from random import Random 

29from typing import Iterator, Optional 

30 

31from lsst.daf.butler import ( 

32 Config, 

33 DataCoordinate, 

34 DataCoordinateSequence, 

35 DataCoordinateSet, 

36 Dimension, 

37 DimensionConfig, 

38 DimensionGraph, 

39 DimensionUniverse, 

40 NamedKeyDict, 

41 NamedValueSet, 

42 Registry, 

43 SpatialRegionDatabaseRepresentation, 

44 TimespanDatabaseRepresentation, 

45 YamlRepoImportBackend, 

46) 

47from lsst.daf.butler.registry import RegistryConfig 

48 

# Path to the YAML export file holding the dimension records used as shared
# test data by this module.
_TEST_DATA_DIR = os.path.dirname(__file__)
DIMENSION_DATA_FILE = os.path.normpath(os.path.join(_TEST_DATA_DIR, "data", "registry", "hsc-rc2-subset.yaml"))

52 

53 

def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code
    repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file, fully
        expanded with records.
    """
    # An in-memory SQLite database and Registry are used purely as a vehicle
    # for importing the YAML data and retrieving it as DataCoordinate objects.
    registryConfig = RegistryConfig()
    registryConfig["db"] = "sqlite://"
    registry = Registry.createFromConfig(registryConfig)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    graph = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(graph).expanded().toSequence()

73 

74 

class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # Every test uses the default ("daf_butler") dimension universe.
        self.universe = DimensionUniverse()

85 

    def checkGraphInvariants(self, graph):
        """Assert the self-consistency invariants that every
        `DimensionGraph` must satisfy (helper, not a test method itself).
        """
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        # Rebuilding the graph from just its required dimensions must
        # round-trip to an equal graph.
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        # "required" is exactly the set of dimensions not implied by anything
        # else in the graph.
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and element that is implied by any other in the graph
        # follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

127 

128 def testConfigPresent(self): 

129 config = self.universe.dimensionConfig 

130 self.assertIsInstance(config, DimensionConfig) 

131 

    def testCompatibility(self):
        """Test `DimensionUniverse.isCompatibleWith` against itself, a
        version-bumped clone, and a structurally different universe.
        """
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        # A version mismatch alone is still compatible, but it should be
        # logged at INFO level.
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

186 

187 def testVersion(self): 

188 self.assertEqual(self.universe.namespace, "daf_butler") 

189 # Test was added starting at version 2. 

190 self.assertGreaterEqual(self.universe.version, 2) 

191 

192 def testConfigRead(self): 

193 self.assertEqual( 

194 set(self.universe.getStaticDimensions().names), 

195 { 

196 "instrument", 

197 "visit", 

198 "visit_system", 

199 "exposure", 

200 "detector", 

201 "physical_filter", 

202 "band", 

203 "subfilter", 

204 "skymap", 

205 "tract", 

206 "patch", 

207 } 

208 | {f"htm{level}" for level in range(25)} 

209 | {f"healpix{level}" for level in range(18)}, 

210 ) 

211 

212 def testGraphs(self): 

213 self.checkGraphInvariants(self.universe.empty) 

214 for element in self.universe.getStaticElements(): 

215 self.checkGraphInvariants(element.graph) 

216 

    def testInstrumentDimensions(self):
        """Check the expanded graph produced for instrument-level
        dimensions (exposure/detector/visit).
        """
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        # Non-dimension elements brought in by these dimensions.
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

229 

    def testCalibrationDimensions(self):
        """Check the expanded graph produced for calibration-related
        dimensions (physical_filter/detector).
        """
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        # No extra (non-dimension) elements should appear for this graph.
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

237 

    def testObservationDimensions(self):
        """Check the spatial/temporal families and governors of the
        observation dimensions (exposure/detector/visit).
        """
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        # Both families are governed by the instrument dimension.
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

256 

    def testSkyMapDimensions(self):
        """Check the expanded graph, spatial family, and governor for the
        skymap dimensions (patch).
        """
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

267 

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

277 

    def testSchemaGeneration(self):
        """Test that table specifications generated for dimension elements
        are consistent with the element definitions themselves.
        """
        tableSpecs = NamedKeyDict({})
        # Only elements that actually have their own (non-view) table get a
        # generated table specification.
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    RegionReprClass=SpatialRegionDatabaseRepresentation,
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            # Required dependencies appear as primary-key fields whose type
            # information matches the dependency's own primary key.
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        # The element's own key field uses its declared
                        # primary-key name rather than the dimension name.
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            # Implied dependencies appear as non-primary-key fields.
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            # Foreign keys must point at other generated tables, with
            # source/target fields matching in type information.
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

330 

331 def testPickling(self): 

332 # Pickling and copying should always yield the exact same object within 

333 # a single process (cross-process is impossible to test here). 

334 universe1 = DimensionUniverse() 

335 universe2 = pickle.loads(pickle.dumps(universe1)) 

336 universe3 = copy.copy(universe1) 

337 universe4 = copy.deepcopy(universe1) 

338 self.assertIs(universe1, universe2) 

339 self.assertIs(universe1, universe3) 

340 self.assertIs(universe1, universe4) 

341 for element1 in universe1.getStaticElements(): 

342 element2 = pickle.loads(pickle.dumps(element1)) 

343 self.assertIs(element1, element2) 

344 graph1 = element1.graph 

345 graph2 = pickle.loads(pickle.dumps(graph1)) 

346 self.assertIs(graph1, graph2) 

347 

348 

@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: Optional[DataCoordinateSequence] = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), iterate over only the ``nth``
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        # A slice keeps the "take everything" and "take one element" cases
        # uniform below.
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]

401 

402 

class DataCoordinateTestCase(unittest.TestCase):
    """Tests for `DataCoordinate` and its container classes, driven by data
    IDs loaded from the repository's test export file.
    """

    # Fixed seed so "random" selections are reproducible across runs.
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        # Load the (expensive) shared test data once for the whole class.
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        # Re-seed per test so tests are independent of execution order.
        self.rng = Random(self.RANDOM_SEED)

412 

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` Data IDs randomly selected from ``dataIds`` without
            replacement (`random.sample` semantics).
        """
        if dataIds is None:
            dataIds = self.allDataIds
        # The selection inherits the parent's state flags, so the
        # per-element consistency check can be skipped.
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

437 

438 def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph: 

439 """Generate a random `DimensionGraph` that has a subset of the 

440 dimensions in a given one. 

441 

442 Parameters 

443 ---------- 

444 n : `int` 

445 Number of dimensions to select, before automatic expansion by 

446 `DimensionGraph`. 

447 dataIds : `DimensionGraph`, optional 

448 Dimensions to select from. Defaults to ``self.allDataIds.graph``. 

449 

450 Returns 

451 ------- 

452 selected : `DimensionGraph` 

453 ``n`` or more dimensions randomly selected from ``graph`` with 

454 replacement. 

455 """ 

456 if graph is None: 

457 graph = self.allDataIds.graph 

458 return DimensionGraph( 

459 graph.universe, names=self.rng.sample(list(graph.dimensions.names), max(n, len(graph.dimensions))) 

460 ) 

461 

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional.
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDS that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Re-standardizing from the full value mapping strips the records
            # but keeps all dimension values.
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # Re-standardizing from just the required-value mapping strips
            # both records and implied-dimension values.
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

516 

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    # Indexing by Dimension object and by name must agree.
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            # Only complete/expanded data IDs provide the ``full`` mapping.
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])

536 

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
        # NOTE(review): ``a1`` and ``b0`` below are the leftover loop
        # variables from the loops above (their last visited combination);
        # presumably this is deliberate, to also compare value-1 against
        # value-0 IDs in both orders — confirm.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(a1, b0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(a1, b0.byName())
            self.assertNotEqual(a1.byName(), b0)

563 

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

642 

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    # If one operand's graph already covers the whole union,
                    # its state flags alone determine the union's.
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

681 

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            # visit+detector: the region comes from the finer-grained
            # visit_detector_region element, not the visit itself.
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)

710 

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

726 

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            # Expanded elements: both flags True, with or without checking.
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            # Complete elements: hasFull True, hasRecords False; claiming
            # records with check=True must fail.
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            # Minimal elements: hasFull only when nothing is implied.
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

757 

758 def testSetOperations(self): 

759 """Test for self-consistency across DataCoordinateSet's operations.""" 

760 c = self.randomDataIds(n=10).toSet() 

761 a = self.randomDataIds(n=20).toSet() | c 

762 b = self.randomDataIds(n=20).toSet() | c 

763 # Make sure we don't have a particularly unlucky random seed, since 

764 # that would make a lot of this test uninteresting. 

765 self.assertNotEqual(a, b) 

766 self.assertGreater(len(a), 0) 

767 self.assertGreater(len(b), 0) 

768 # The rest of the tests should not depend on the random seed. 

769 self.assertEqual(a, a) 

770 self.assertNotEqual(a, a.toSequence()) 

771 self.assertEqual(a, a.toSequence().toSet()) 

772 self.assertEqual(a, a.toSequence().toSet()) 

773 self.assertEqual(b, b) 

774 self.assertNotEqual(b, b.toSequence()) 

775 self.assertEqual(b, b.toSequence().toSet()) 

776 self.assertEqual(a & b, a.intersection(b)) 

777 self.assertLessEqual(a & b, a) 

778 self.assertLessEqual(a & b, b) 

779 self.assertEqual(a | b, a.union(b)) 

780 self.assertGreaterEqual(a | b, a) 

781 self.assertGreaterEqual(a | b, b) 

782 self.assertEqual(a - b, a.difference(b)) 

783 self.assertLessEqual(a - b, a) 

784 self.assertLessEqual(b - a, b) 

785 self.assertEqual(a ^ b, a.symmetric_difference(b)) 

786 self.assertGreaterEqual(a ^ b, (a | b) - (a & b)) 

787 

788 

# Allow this test module to be run directly (e.g. ``python test_dimensions.py``).
if __name__ == "__main__":
    unittest.main()