Coverage for tests/test_dimensions.py: 11% (371 statements), coverage.py v6.4.1, created at 2022-07-03 01:08 -0700

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import os
import pickle
import unittest
from dataclasses import dataclass
from random import Random
from typing import Iterator, Optional

from lsst.daf.butler import (
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    Registry,
    SpatialRegionDatabaseRepresentation,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
    backend.register()
    backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
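
# For orientation (an illustrative sketch, not executed by the tests): each
# DataCoordinate returned above behaves like a mapping from dimension names to
# primary-key values, e.g. something roughly like
#     {"instrument": "HSC", "visit": 12345, "detector": 42, ...}
# The specific values here are hypothetical; the real ones come from
# data/registry/hsc-rc2-subset.yaml.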


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and any element that is implied by another in the graph
        # should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)
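
        # A concrete instance of these invariants (see testInstrumentDimensions
        # below): for the graph built from ("exposure", "detector", "visit"),
        # graph.required is {instrument, exposure, detector, visit} and
        # graph.implied is {physical_filter, band}.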

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # Test was added starting at version 2.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal families are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
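
        # Why three spatial families here: visit/detector contribute
        # "observation_regions", tract/patch contribute "skymap_regions", and
        # htm7 belongs to the "htm" pixelization family; the observation
        # dimensions share the single temporal family "observation_timespans".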

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    RegionReprClass=SpatialRegionDatabaseRepresentation,
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )
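
        # In words: each element's table carries its required dependencies as
        # primary-key columns, its implied dependencies as ordinary (non-key)
        # columns, and foreign keys whose column types match the referenced
        # dimension tables; the assertions above check exactly that.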

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: Optional[DataCoordinateSequence] = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """
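
    # Illustrative example with hypothetical values, for a graph whose
    # required dimensions are {instrument, visit} and whose implied dimensions
    # are {physical_filter, band}:
    #   minimal:  {"instrument": "HSC", "visit": 12345}
    #   complete: adds {"physical_filter": "HSC-R", "band": "r"}
    #   expanded: the complete values plus their DimensionRecords.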

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (default is `None`), iterate over only the ``n``-th
            data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]



class DataCoordinateTestCase(unittest.TestCase):

    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )
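
        # Note: random.Random.sample draws without replacement, so n must not
        # exceed len(dataIds); the export file loaded above is assumed to hold
        # enough data IDs for the sample sizes used in these tests.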

    def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from. Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        return DimensionGraph(
            graph.universe, names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions)))
        )
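
        # Because DimensionGraph expands dependencies automatically, the
        # result can be larger than n: sampling just "patch", for example,
        # yields the graph {skymap, tract, patch} (compare
        # testSkyMapDimensions above).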

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default), include the original data IDs that contain
            all information in the result.
        complete : `bool`, optional
            If `True` (default), include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default), include data IDs that only contain values
            for required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions])
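
        # Takeaway from the assertions above: iterating a DataCoordinate
        # directly exposes only the required dimensions (dataId.keys() equals
        # graph.required); the implied dimensions are reachable through the
        # ``full`` view when hasFull() is True.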

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(b1, a0.byName())
            self.assertNotEqual(b1.byName(), a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a subset
                # of the dimensions. This is not possible for some
                # combinations of dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from Complete to
                            # Minimal or from Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets,
                # so we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())
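
        # The hasFull()/hasRecords() checks above encode propagation rules for
        # union: completeness is asserted to survive when both operands are
        # complete, or when one operand already spans the unioned graph, and
        # records are asserted to survive when the operands together cover
        # every element of the unioned graph.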

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
        # Also test the case for non-temporal data IDs.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))
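
        # Mathematically a ^ b == (a | b) - (a & b) for sets; the final
        # assertion checks only the containment direction.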


if __name__ == "__main__":
    unittest.main()