Coverage for tests/test_dimensions.py: 11%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

365 statements  

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import copy 

23import itertools 

24import os 

25import pickle 

26import unittest 

27from dataclasses import dataclass 

28from random import Random 

29from typing import Iterator, Optional 

30 

31from lsst.daf.butler import ( 

32 DataCoordinate, 

33 DataCoordinateSequence, 

34 DataCoordinateSet, 

35 Dimension, 

36 DimensionConfig, 

37 DimensionGraph, 

38 DimensionUniverse, 

39 NamedKeyDict, 

40 NamedValueSet, 

41 Registry, 

42 SpatialRegionDatabaseRepresentation, 

43 TimespanDatabaseRepresentation, 

44 YamlRepoImportBackend, 

45) 

46from lsst.daf.butler.registry import RegistryConfig 

47 

# Absolute, normalized path to the dimension-data export file (shipped with
# the test suite) that is imported into an in-memory registry by
# loadDimensionData() below.
DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)

51 

52 

def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file for the
        ``visit``, ``detector``, ``tract``, and ``patch`` dimensions,
        expanded to include dimension records.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()

72 

73 

class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        # A fresh default universe for each test; tests only read from it.
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        """Check invariants that must hold for any `DimensionGraph`.

        Parameters
        ----------
        graph : `DimensionGraph`
            Graph to check.
        """
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering within
            # a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        # A graph must round-trip through its required dimensions.
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        # The required dimensions are exactly those not implied by any
        # element of the graph.
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check primary key traversal order: each element should follow any it
        # requires, and an element that is implied by any other in the graph
        # should follow at least one of those.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)

    def testConfigPresent(self):
        """Test that the universe exposes the configuration it was built
        from.
        """
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testConfigRead(self):
        """Test that the static dimensions defined in the default
        configuration are the expected ones.
        """
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        """Test graph invariants for the empty graph and the graph of every
        static element.
        """
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)

    def testInstrumentDimensions(self):
        """Test expansion and categorization of instrument-level
        dimensions.
        """
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band", "visit_system"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band", "visit_system"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        """Test expansion and categorization of dimensions typical of
        calibration datasets.
        """
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        """Test spatial/temporal families and governors for observation
        dimensions.
        """
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band", "visit_system"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band", "visit_system"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        """Test spatial families and governors for skymap dimensions."""
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        """Test that generated table specs have consistent fields and
        foreign keys for all static elements with tables.
        """
        tableSpecs = NamedKeyDict({})
        # Build a table spec for every element that is backed by a real
        # table (not a view).
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    RegionReprClass=SpatialRegionDatabaseRepresentation,
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            # Required dependencies must appear as non-nullable primary-key
            # fields matching the dependency's own primary-key definition.
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        # The element's own key uses its primaryKey name, not
                        # the dimension name.
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            # Implied dependencies appear as ordinary (non-key) fields.
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            # Foreign keys must point at other generated tables with
            # type-compatible source/target columns.
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object within
        # a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)

286 

287 

@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    # Data IDs carrying only values for required dimensions.
    # ``DataCoordinateSequence.hasFull()`` is `True` for these if and only if
    # ``minimal.graph.implied`` has no elements;
    # ``DataCoordinate.hasRecords()`` always returns `False`.
    minimal: Optional[DataCoordinateSequence] = None

    # Data IDs carrying values for all dimensions.
    # ``DataCoordinateSequence.hasFull()`` always returns `True`;
    # ``DataCoordinate.hasRecords()`` always returns `False`.
    complete: Optional[DataCoordinateSequence] = None

    # Data IDs carrying values for all dimensions as well as records.
    # Both ``DataCoordinateSequence.hasFull()`` and
    # ``DataCoordinate.hasRecords()`` always return `True`.
    expanded: Optional[DataCoordinateSequence] = None

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs held in each populated attribute.

        Parameters
        ----------
        n : `int`, optional
            If provided (`None` is default), yield only the ``nth`` data ID
            from each attribute instead of all of them.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct, in
            ``minimal``, ``complete``, ``expanded`` order.
        """
        window = slice(None) if n is None else slice(n, n + 1)
        for sequence in (self.minimal, self.complete, self.expanded):
            if sequence is not None:
                yield from sequence[window]

340 

341 

class DataCoordinateTestCase(unittest.TestCase):
    """Tests for `DataCoordinate` and its set/sequence containers, using
    data IDs loaded from an export file.
    """

    # Fixed seed so any failure is reproducible from run to run.
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        # Loading the export file is expensive, so do it once per class.
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        # Re-seed for every test so results do not depend on execution order.
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from.  Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from.  Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement (more than ``n`` when `DimensionGraph` expansion pulls
            in dependencies).
        """
        if graph is None:
            graph = self.allDataIds.graph
        # Use min() here: random.sample raises ValueError if the sample size
        # exceeds the population, and the previous max() always selected
        # every dimension, so the "subset" was never actually a subset.
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional.
            Data IDs to start from.  Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            # Round-trip through plain dicts to drop the records but keep all
            # dimension values.
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            # Keep only required-dimension values.
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    # Keys may be looked up by Dimension instance or by name.
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(
                        list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]
                    )

    def testEquality(self):
        """Test that different `DataCoordinate` instances with different state
        flags can be compared with each other and other mappings.
        """
        dataIds = self.randomDataIds(n=2)
        split = self.splitByStateFlags(dataIds)
        # Iterate over all combinations of different states of DataCoordinate,
        # with the same underlying data ID values.
        for a0, b0 in itertools.combinations(split.chain(0), 2):
            self.assertEqual(a0, b0)
            self.assertEqual(a0, b0.byName())
            self.assertEqual(a0.byName(), b0)
        # Same thing, for a different data ID value.
        for a1, b1 in itertools.combinations(split.chain(1), 2):
            self.assertEqual(a1, b1)
            self.assertEqual(a1, b1.byName())
            self.assertEqual(a1.byName(), b1)
        # Iterate over all combinations of different states of DataCoordinate,
        # with different underlying data ID values.  (The original compared
        # stale loop variables leaked from the loops above; compare the
        # current pair symmetrically instead.)
        for a0, b1 in itertools.product(split.chain(0), split.chain(1)):
            self.assertNotEqual(a0, b1)
            self.assertNotEqual(b1, a0)
            self.assertNotEqual(a0, b1.byName())
            self.assertNotEqual(a0.byName(), b1)
            self.assertNotEqual(b1, a0.byName())
            self.assertNotEqual(b1.byName(), a0)

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for _ in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via several
            # different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    # If either operand alone covers the whole union, its
                    # state flags should propagate.
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))

727 

728 

if __name__ == "__main__":
    # Run all tests in this module when executed as a script.
    unittest.main()