
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import itertools
import os
import pickle
import unittest
from dataclasses import dataclass
from random import Random
from typing import Iterator, Optional

import lsst.sphgeom
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateSequence,
    DataCoordinateSet,
    Dimension,
    DimensionConfig,
    DimensionGraph,
    DimensionUniverse,
    NamedKeyDict,
    NamedValueSet,
    Registry,
    TimespanDatabaseRepresentation,
    YamlRepoImportBackend,
)
from lsst.daf.butler.registry import RegistryConfig

DIMENSION_DATA_FILE = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "data", "registry", "hsc-rc2-subset.yaml")
)


def loadDimensionData() -> DataCoordinateSequence:
    """Load dimension data from an export file included in the code repository.

    Returns
    -------
    dataIds : `DataCoordinateSequence`
        A sequence containing all data IDs in the export file.
    """
    # Create an in-memory SQLite database and Registry just to import the YAML
    # data and retrieve it as a set of DataCoordinate objects.
    config = RegistryConfig()
    config["db"] = "sqlite://"
    registry = Registry.createFromConfig(config)
    with open(DIMENSION_DATA_FILE, "r") as stream:
        backend = YamlRepoImportBackend(stream, registry)
        backend.register()
        backend.load(datastore=None)
    dimensions = DimensionGraph(registry.dimensions, names=["visit", "detector", "tract", "patch"])
    return registry.queryDataIds(dimensions).expanded().toSequence()
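# Note on the helper above (a summary, not new behavior): "sqlite://" with no
# path creates an in-memory SQLite database, so this Registry never touches
# disk, and ``datastore=None`` tells the import backend to load registry
# content without transferring any file artifacts.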


class DimensionTestCase(unittest.TestCase):
    """Tests for dimensions.

    All tests here rely on the content of ``config/dimensions.yaml``, either
    to test that the definitions there are read in properly or just as generic
    data for testing various operations.
    """

    def setUp(self):
        self.universe = DimensionUniverse()

    def checkGraphInvariants(self, graph):
        """Check invariants that any `DimensionGraph` should satisfy."""
        elements = list(graph.elements)
        for n, element in enumerate(elements):
            # Ordered comparisons on graphs behave like sets.
            self.assertLessEqual(element.graph, graph)
            # Ordered comparisons on elements correspond to the ordering
            # within a DimensionUniverse (topological, with deterministic
            # tiebreakers).
            for other in elements[:n]:
                self.assertLess(other, element)
                self.assertLessEqual(other, element)
            for other in elements[n + 1 :]:
                self.assertGreater(other, element)
                self.assertGreaterEqual(other, element)
            if isinstance(element, Dimension):
                self.assertEqual(element.graph.required, element.required)
        self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
        self.assertCountEqual(
            graph.required,
            [
                dimension
                for dimension in graph.dimensions
                if not any(dimension in other.graph.implied for other in graph.elements)
            ],
        )
        self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
        self.assertCountEqual(
            graph.dimensions, [element for element in graph.elements if isinstance(element, Dimension)]
        )
        self.assertCountEqual(graph.dimensions, itertools.chain(graph.required, graph.implied))
        # Check the primary key traversal order: each element should follow
        # all elements it requires, and any element implied by another in the
        # graph should follow at least one of the elements that imply it.
        seen = NamedValueSet()
        for element in graph.primaryKeyTraversalOrder:
            with self.subTest(required=graph.required, implied=graph.implied, element=element):
                seen.add(element)
                self.assertLessEqual(element.graph.required, seen)
                if element in graph.implied:
                    self.assertTrue(any(element in s.implied for s in seen))
        self.assertCountEqual(seen, graph.elements)
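    # Illustrative consequence of the ordering invariant checked above (the
    # exact order depends on config/dimensions.yaml): one would expect, e.g.,
    #
    #     self.universe["instrument"] < self.universe["visit"]
    #
    # because "visit" requires "instrument" and the order is topological.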

    def testConfigPresent(self):
        config = self.universe.dimensionConfig
        self.assertIsInstance(config, DimensionConfig)

    def testCompatibility(self):
        # Simple check that should always be true.
        self.assertTrue(self.universe.isCompatibleWith(self.universe))

        # Create a universe like the default universe but with a different
        # version number.
        clone = self.universe.dimensionConfig.copy()
        clone["version"] = clone["version"] + 1_000_000  # High version number
        universe_clone = DimensionUniverse(config=clone)
        with self.assertLogs("lsst.daf.butler.core.dimensions", "INFO") as cm:
            self.assertTrue(self.universe.isCompatibleWith(universe_clone))
        self.assertIn("differing versions", "\n".join(cm.output))

        # Create a completely incompatible universe.
        config = Config(
            {
                "version": 1,
                "namespace": "compat_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe2 = DimensionUniverse(config=config)
        self.assertFalse(universe2.isCompatibleWith(self.universe))

    def testVersion(self):
        self.assertEqual(self.universe.namespace, "daf_butler")
        # This check was added when the universe was at version 2, so the
        # version should never be lower than that.
        self.assertGreaterEqual(self.universe.version, 2)

    def testConfigRead(self):
        self.assertEqual(
            set(self.universe.getStaticDimensions().names),
            {
                "instrument",
                "visit",
                "visit_system",
                "exposure",
                "detector",
                "physical_filter",
                "band",
                "subfilter",
                "skymap",
                "tract",
                "patch",
            }
            | {f"htm{level}" for level in range(25)}
            | {f"healpix{level}" for level in range(18)},
        )

    def testGraphs(self):
        self.checkGraphInvariants(self.universe.empty)
        for element in self.universe.getStaticElements():
            self.checkGraphInvariants(element.graph)
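    # Orientation for the graph tests below (illustrative; the exact contents
    # come from config/dimensions.yaml): DimensionGraph expands the names it
    # is given to include their dependencies, so e.g.
    #
    #     DimensionGraph(self.universe, names=("visit",))
    #
    # would also contain "instrument" (required) plus "physical_filter" and
    # "band" (implied), as the assertions below spell out.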

    def testInstrumentDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "exposure", "detector", "visit", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "exposure", "detector", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testCalibrationDimensions(self):
        graph = DimensionGraph(self.universe, names=("physical_filter", "detector"))
        self.assertCountEqual(graph.dimensions.names, ("instrument", "detector", "physical_filter", "band"))
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "physical_filter"))
        self.assertCountEqual(graph.implied.names, ("band",))
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.governors.names, {"instrument"})

    def testObservationDimensions(self):
        graph = DimensionGraph(self.universe, names=("exposure", "detector", "visit"))
        self.assertCountEqual(
            graph.dimensions.names,
            ("instrument", "detector", "visit", "exposure", "physical_filter", "band"),
        )
        self.assertCountEqual(graph.required.names, ("instrument", "detector", "exposure", "visit"))
        self.assertCountEqual(graph.implied.names, ("physical_filter", "band"))
        self.assertCountEqual(
            graph.elements.names - graph.dimensions.names, ("visit_detector_region", "visit_definition")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions",))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))
        self.assertCountEqual(graph.governors.names, {"instrument"})
        self.assertEqual(graph.spatial.names, {"observation_regions"})
        self.assertEqual(graph.temporal.names, {"observation_timespans"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["instrument"])
        self.assertEqual(next(iter(graph.temporal)).governor, self.universe["instrument"])

    def testSkyMapDimensions(self):
        graph = DimensionGraph(self.universe, names=("patch",))
        self.assertCountEqual(graph.dimensions.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.required.names, ("skymap", "tract", "patch"))
        self.assertCountEqual(graph.implied.names, ())
        self.assertCountEqual(graph.elements.names, graph.dimensions.names)
        self.assertCountEqual(graph.spatial.names, ("skymap_regions",))
        self.assertCountEqual(graph.governors.names, {"skymap"})
        self.assertEqual(graph.spatial.names, {"skymap_regions"})
        self.assertEqual(next(iter(graph.spatial)).governor, self.universe["skymap"])

    def testSubsetCalculation(self):
        """Test that independent spatial and temporal options are computed
        correctly.
        """
        graph = DimensionGraph(
            self.universe, names=("visit", "detector", "tract", "patch", "htm7", "exposure")
        )
        self.assertCountEqual(graph.spatial.names, ("observation_regions", "skymap_regions", "htm"))
        self.assertCountEqual(graph.temporal.names, ("observation_timespans",))

    def testSchemaGeneration(self):
        tableSpecs = NamedKeyDict({})
        for element in self.universe.getStaticElements():
            if element.hasTable and element.viewOf is None:
                tableSpecs[element] = element.RecordClass.fields.makeTableSpec(
                    TimespanReprClass=TimespanDatabaseRepresentation.Compound,
                )
        for element, tableSpec in tableSpecs.items():
            for dep in element.required:
                with self.subTest(element=element.name, dep=dep.name):
                    if dep != element:
                        self.assertIn(dep.name, tableSpec.fields)
                        self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                        self.assertEqual(tableSpec.fields[dep.name].length, dep.primaryKey.length)
                        self.assertEqual(tableSpec.fields[dep.name].nbytes, dep.primaryKey.nbytes)
                        self.assertFalse(tableSpec.fields[dep.name].nullable)
                        self.assertTrue(tableSpec.fields[dep.name].primaryKey)
                    else:
                        self.assertIn(element.primaryKey.name, tableSpec.fields)
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].dtype, dep.primaryKey.dtype
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].length, dep.primaryKey.length
                        )
                        self.assertEqual(
                            tableSpec.fields[element.primaryKey.name].nbytes, dep.primaryKey.nbytes
                        )
                        self.assertFalse(tableSpec.fields[element.primaryKey.name].nullable)
                        self.assertTrue(tableSpec.fields[element.primaryKey.name].primaryKey)
            for dep in element.implied:
                with self.subTest(element=element.name, dep=dep.name):
                    self.assertIn(dep.name, tableSpec.fields)
                    self.assertEqual(tableSpec.fields[dep.name].dtype, dep.primaryKey.dtype)
                    self.assertFalse(tableSpec.fields[dep.name].primaryKey)
            for foreignKey in tableSpec.foreignKeys:
                self.assertIn(foreignKey.table, tableSpecs)
                self.assertIn(foreignKey.table, element.graph.dimensions.names)
                self.assertEqual(len(foreignKey.source), len(foreignKey.target))
                for source, target in zip(foreignKey.source, foreignKey.target):
                    self.assertIn(source, tableSpec.fields.names)
                    self.assertIn(target, tableSpecs[foreignKey.table].fields.names)
                    self.assertEqual(
                        tableSpec.fields[source].dtype, tableSpecs[foreignKey.table].fields[target].dtype
                    )
                    self.assertEqual(
                        tableSpec.fields[source].length, tableSpecs[foreignKey.table].fields[target].length
                    )
                    self.assertEqual(
                        tableSpec.fields[source].nbytes, tableSpecs[foreignKey.table].fields[target].nbytes
                    )

    def testPickling(self):
        # Pickling and copying should always yield the exact same object
        # within a single process (cross-process is impossible to test here).
        universe1 = DimensionUniverse()
        universe2 = pickle.loads(pickle.dumps(universe1))
        universe3 = copy.copy(universe1)
        universe4 = copy.deepcopy(universe1)
        self.assertIs(universe1, universe2)
        self.assertIs(universe1, universe3)
        self.assertIs(universe1, universe4)
        for element1 in universe1.getStaticElements():
            element2 = pickle.loads(pickle.dumps(element1))
            self.assertIs(element1, element2)
            graph1 = element1.graph
            graph2 = pickle.loads(pickle.dumps(graph1))
            self.assertIs(graph1, graph2)


@dataclass
class SplitByStateFlags:
    """A struct that separates data IDs with different states but the same
    values.
    """

    minimal: Optional[DataCoordinateSequence] = None
    """Data IDs that only contain values for required dimensions.

    `DataCoordinateSequence.hasFull()` will return `True` for this if and only
    if ``minimal.graph.implied`` has no elements.
    `DataCoordinate.hasRecords()` will always return `False`.
    """

    complete: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions.

    `DataCoordinateSequence.hasFull()` will always return `True` and
    `DataCoordinate.hasRecords()` will always return `False` for this
    attribute.
    """

    expanded: Optional[DataCoordinateSequence] = None
    """Data IDs that contain values for all dimensions as well as records.

    `DataCoordinateSequence.hasFull()` and `DataCoordinate.hasRecords()` will
    always return `True` for this attribute.
    """

    def chain(self, n: Optional[int] = None) -> Iterator:
        """Iterate over the data IDs of different types.

        Parameters
        ----------
        n : `int`, optional
            If provided (the default is `None`, meaning all), iterate over
            only the ``n``th data ID in each attribute.

        Yields
        ------
        dataId : `DataCoordinate`
            A data ID from one of the attributes in this struct.
        """
        if n is None:
            s = slice(None, None)
        else:
            s = slice(n, n + 1)
        if self.minimal is not None:
            yield from self.minimal[s]
        if self.complete is not None:
            yield from self.complete[s]
        if self.expanded is not None:
            yield from self.expanded[s]
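# Illustrative summary of the three states above (values hypothetical): for a
# graph with required dimensions {instrument, visit, detector} and implied
# {physical_filter, band}, the same underlying data ID would appear as
#
#     minimal:  values for instrument, visit, and detector only
#     complete: values for all five dimensions (hasFull() is True)
#     expanded: complete, plus DimensionRecords (hasRecords() is True)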


class DataCoordinateTestCase(unittest.TestCase):
    RANDOM_SEED = 10

    @classmethod
    def setUpClass(cls):
        cls.allDataIds = loadDimensionData()

    def setUp(self):
        self.rng = Random(self.RANDOM_SEED)

    def randomDataIds(self, n: int, dataIds: Optional[DataCoordinateSequence] = None):
        """Select random data IDs from those loaded from test data.

        Parameters
        ----------
        n : `int`
            Number of data IDs to select.
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to select from. Defaults to ``self.allDataIds``.

        Returns
        -------
        selected : `DataCoordinateSequence`
            ``n`` data IDs randomly selected from ``dataIds`` without
            replacement.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        return DataCoordinateSequence(
            self.rng.sample(dataIds, n),
            graph=dataIds.graph,
            hasFull=dataIds.hasFull(),
            hasRecords=dataIds.hasRecords(),
            check=False,
        )

    def randomDimensionSubset(self, n: int = 3, graph: Optional[DimensionGraph] = None) -> DimensionGraph:
        """Generate a random `DimensionGraph` that has a subset of the
        dimensions in a given one.

        Parameters
        ----------
        n : `int`
            Number of dimensions to select, before automatic expansion by
            `DimensionGraph`.
        graph : `DimensionGraph`, optional
            Dimensions to select from. Defaults to ``self.allDataIds.graph``.

        Returns
        -------
        selected : `DimensionGraph`
            ``n`` or more dimensions randomly selected from ``graph`` without
            replacement.
        """
        if graph is None:
            graph = self.allDataIds.graph
        # Clamp the sample size so we never ask for more names than exist;
        # random.sample draws without replacement.
        return DimensionGraph(
            graph.universe,
            names=self.rng.sample(list(graph.dimensions.names), min(n, len(graph.dimensions))),
        )

    def splitByStateFlags(
        self,
        dataIds: Optional[DataCoordinateSequence] = None,
        *,
        expanded: bool = True,
        complete: bool = True,
        minimal: bool = True,
    ) -> SplitByStateFlags:
        """Given a sequence of data IDs, generate new equivalent sequences
        containing less information.

        Parameters
        ----------
        dataIds : `DataCoordinateSequence`, optional
            Data IDs to start from. Defaults to ``self.allDataIds``.
            ``dataIds.hasRecords()`` and ``dataIds.hasFull()`` must both
            return `True`.
        expanded : `bool`, optional
            If `True` (default) include the original data IDs that contain all
            information in the result.
        complete : `bool`, optional
            If `True` (default) include data IDs for which ``hasFull()``
            returns `True` but ``hasRecords()`` does not.
        minimal : `bool`, optional
            If `True` (default) include data IDs that only contain values for
            required dimensions, for which ``hasFull()`` may not return
            `True`.

        Returns
        -------
        split : `SplitByStateFlags`
            A dataclass holding the indicated data IDs in attributes that
            correspond to the boolean keyword arguments.
        """
        if dataIds is None:
            dataIds = self.allDataIds
        assert dataIds.hasFull() and dataIds.hasRecords()
        result = SplitByStateFlags(expanded=dataIds)
        if complete:
            result.complete = DataCoordinateSequence(
                [DataCoordinate.standardize(e.full.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertTrue(result.complete.hasFull())
            self.assertFalse(result.complete.hasRecords())
        if minimal:
            result.minimal = DataCoordinateSequence(
                [DataCoordinate.standardize(e.byName(), graph=dataIds.graph) for e in result.expanded],
                graph=dataIds.graph,
            )
            self.assertEqual(result.minimal.hasFull(), not dataIds.graph.implied)
            self.assertFalse(result.minimal.hasRecords())
        if not expanded:
            result.expanded = None
        return result

    def testMappingInterface(self):
        """Test that the mapping interface in `DataCoordinate` and (when
        applicable) its ``full`` property are self-consistent and consistent
        with the ``graph`` property.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for dataId in split.chain():
                with self.subTest(dataId=dataId):
                    self.assertEqual(list(dataId.values()), [dataId[d] for d in dataId.keys()])
                    self.assertEqual(list(dataId.values()), [dataId[d.name] for d in dataId.keys()])
                    self.assertEqual(dataId.keys(), dataId.graph.required)
            for dataId in itertools.chain(split.complete, split.expanded):
                with self.subTest(dataId=dataId):
                    self.assertTrue(dataId.hasFull())
                    self.assertEqual(dataId.graph.dimensions, dataId.full.keys())
                    self.assertEqual(
                        list(dataId.full.values()), [dataId[k] for k in dataId.graph.dimensions]
                    )

    def test_pickle(self):
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.chain():
                s = pickle.dumps(data_id)
                read_data_id = pickle.loads(s)
                self.assertEqual(data_id, read_data_id)
                self.assertEqual(data_id.hasFull(), read_data_id.hasFull())
                self.assertEqual(data_id.hasRecords(), read_data_id.hasRecords())
                if data_id.hasFull():
                    self.assertEqual(data_id.full, read_data_id.full)
                if data_id.hasRecords():
                    self.assertEqual(data_id.records, read_data_id.records)

    def test_record_attributes(self):
        """Test that dimension records are available as attributes on expanded
        data coordinates.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for data_id in split.expanded:
                for element in data_id.graph.elements:
                    self.assertIs(getattr(data_id, element.name), data_id.records[element.name])
                    self.assertIn(element.name, dir(data_id))
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")
            for data_id in itertools.chain(split.minimal, split.complete):
                for element in data_id.graph.elements:
                    with self.assertRaisesRegex(
                        AttributeError, "only available on expanded DataCoordinates"
                    ):
                        getattr(data_id, element.name)
                with self.assertRaisesRegex(AttributeError, "^not_a_dimension_name$"):
                    getattr(data_id, "not_a_dimension_name")

573 def testEquality(self): 

574 """Test that different `DataCoordinate` instances with different state 

575 flags can be compared with each other and other mappings. 

576 """ 

577 dataIds = self.randomDataIds(n=2) 

578 split = self.splitByStateFlags(dataIds) 

579 # Iterate over all combinations of different states of DataCoordinate, 

580 # with the same underlying data ID values. 

581 for a0, b0 in itertools.combinations(split.chain(0), 2): 

582 self.assertEqual(a0, b0) 

583 self.assertEqual(a0, b0.byName()) 

584 self.assertEqual(a0.byName(), b0) 

585 # Same thing, for a different data ID value. 

586 for a1, b1 in itertools.combinations(split.chain(1), 2): 

587 self.assertEqual(a1, b1) 

588 self.assertEqual(a1, b1.byName()) 

589 self.assertEqual(a1.byName(), b1) 

590 # Iterate over all combinations of different states of DataCoordinate, 

591 # with different underlying data ID values. 

592 for a0, b1 in itertools.product(split.chain(0), split.chain(1)): 

593 self.assertNotEqual(a0, b1) 

594 self.assertNotEqual(a1, b0) 

595 self.assertNotEqual(a0, b1.byName()) 

596 self.assertNotEqual(a0.byName(), b1) 

597 self.assertNotEqual(a1, b0.byName()) 

598 self.assertNotEqual(a1.byName(), b0) 

    def testStandardize(self):
        """Test constructing a DataCoordinate from many different kinds of
        input via `DataCoordinate.standardize` and `DataCoordinate.subset`.
        """
        for n in range(5):
            dimensions = self.randomDimensionSubset()
            dataIds = self.randomDataIds(n=1).subset(dimensions)
            split = self.splitByStateFlags(dataIds)
            for m, dataId in enumerate(split.chain()):
                # Passing in any kind of DataCoordinate alone just returns
                # that object.
                self.assertIs(dataId, DataCoordinate.standardize(dataId))
                # Same if we also explicitly pass the dimensions we want.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph))
                # Same if we pass the dimensions and some irrelevant
                # kwargs.
                self.assertIs(dataId, DataCoordinate.standardize(dataId, graph=dataId.graph, htm7=12))
                # Test constructing a new data ID from this one with a
                # subset of the dimensions.
                # This is not possible for some combinations of
                # dimensions if hasFull is False (see
                # `DataCoordinate.subset` docs).
                newDimensions = self.randomDimensionSubset(n=1, graph=dataId.graph)
                if dataId.hasFull() or dataId.graph.required.issuperset(newDimensions.required):
                    newDataIds = [
                        dataId.subset(newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions),
                        DataCoordinate.standardize(dataId, graph=newDimensions, htm7=12),
                    ]
                    for newDataId in newDataIds:
                        with self.subTest(newDataId=newDataId, type=type(dataId)):
                            commonKeys = dataId.keys() & newDataId.keys()
                            self.assertTrue(commonKeys)
                            self.assertEqual(
                                [newDataId[k] for k in commonKeys],
                                [dataId[k] for k in commonKeys],
                            )
                            # This should never "downgrade" from
                            # Complete to Minimal or Expanded to Complete.
                            if dataId.hasRecords():
                                self.assertTrue(newDataId.hasRecords())
                            if dataId.hasFull():
                                self.assertTrue(newDataId.hasFull())
            # Start from a complete data ID, and pass its values in via
            # several different ways that should be equivalent.
            for dataId in split.complete:
                # Split the keys (dimension names) into two random subsets, so
                # we can pass some as kwargs below.
                keys1 = set(
                    self.rng.sample(list(dataId.graph.dimensions.names), len(dataId.graph.dimensions) // 2)
                )
                keys2 = dataId.graph.dimensions.names - keys1
                newCompleteDataIds = [
                    DataCoordinate.standardize(dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(dataId.full.byName(), graph=dataId.graph),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe), **dataId.full.byName()
                    ),
                    DataCoordinate.standardize(
                        DataCoordinate.makeEmpty(dataId.graph.universe),
                        graph=dataId.graph,
                        **dataId.full.byName(),
                    ),
                    DataCoordinate.standardize(**dataId.full.byName(), universe=dataId.universe),
                    DataCoordinate.standardize(graph=dataId.graph, **dataId.full.byName()),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1},
                        universe=dataId.universe,
                        **{k: dataId[k] for k in keys2},
                    ),
                    DataCoordinate.standardize(
                        {k: dataId[k] for k in keys1}, graph=dataId.graph, **{k: dataId[k] for k in keys2}
                    ),
                ]
                for newDataId in newCompleteDataIds:
                    with self.subTest(dataId=dataId, newDataId=newDataId, type=type(dataId)):
                        self.assertEqual(dataId, newDataId)
                        self.assertTrue(newDataId.hasFull())

    def testUnion(self):
        """Test `DataCoordinate.union`."""
        # Make test graphs to combine; mostly random, but with a few explicit
        # cases to make sure certain edge cases are covered.
        graphs = [self.randomDimensionSubset(n=2) for i in range(2)]
        graphs.append(self.allDataIds.universe["visit"].graph)
        graphs.append(self.allDataIds.universe["detector"].graph)
        graphs.append(self.allDataIds.universe["physical_filter"].graph)
        graphs.append(self.allDataIds.universe["band"].graph)
        # Iterate over all combinations, including the same graph with itself.
        for graph1, graph2 in itertools.product(graphs, repeat=2):
            parentDataIds = self.randomDataIds(n=1)
            split1 = self.splitByStateFlags(parentDataIds.subset(graph1))
            split2 = self.splitByStateFlags(parentDataIds.subset(graph2))
            (parentDataId,) = parentDataIds
            for lhs, rhs in itertools.product(split1.chain(), split2.chain()):
                unioned = lhs.union(rhs)
                with self.subTest(lhs=lhs, rhs=rhs, unioned=unioned):
                    self.assertEqual(unioned.graph, graph1.union(graph2))
                    self.assertEqual(unioned, parentDataId.subset(unioned.graph))
                    if unioned.hasFull():
                        self.assertEqual(unioned.subset(lhs.graph), lhs)
                        self.assertEqual(unioned.subset(rhs.graph), rhs)
                    if lhs.hasFull() and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                    if lhs.graph >= unioned.graph and lhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if lhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if rhs.graph >= unioned.graph and rhs.hasFull():
                        self.assertTrue(unioned.hasFull())
                        if rhs.hasRecords():
                            self.assertTrue(unioned.hasRecords())
                    if lhs.graph.required | rhs.graph.required >= unioned.graph.dimensions:
                        self.assertTrue(unioned.hasFull())
                    if lhs.hasRecords() and rhs.hasRecords():
                        if lhs.graph.elements | rhs.graph.elements >= unioned.graph.elements:
                            self.assertTrue(unioned.hasRecords())
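    # (Taken together, the union() checks above encode when state survives:
    # the result has full values when both operands do, when either operand
    # alone spans the unioned graph, or when the operands' required
    # dimensions jointly cover it; it has records when a record-bearing
    # operand spans the graph or the operands jointly cover every element.)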

    def testRegions(self):
        """Test that data IDs for a few known dimensions have the expected
        regions.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "detector"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"observation_regions"})
            self.assertEqual(dataId.region, dataId.records["visit_detector_region"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["tract"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["tract"].region)
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNotNone(dataId.region)
            self.assertEqual(dataId.graph.spatial.names, {"skymap_regions"})
            self.assertEqual(dataId.region, dataId.records["patch"].region)
        for data_id in self.randomDataIds(n=1).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit", "tract"])
        ):
            self.assertEqual(data_id.region.relate(data_id.records["visit"].region), lsst.sphgeom.WITHIN)
            self.assertEqual(data_id.region.relate(data_id.records["tract"].region), lsst.sphgeom.WITHIN)
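        # (The two WITHIN checks above are consistent with the region of a
        # {visit, tract} data ID being the intersection of the visit and
        # tract regions; the test only asserts containment in each, though.)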

    def testTimespans(self):
        """Test that data IDs for a few known dimensions have the expected
        timespans.
        """
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["visit"])
        ):
            self.assertIsNotNone(dataId.timespan)
            self.assertEqual(dataId.graph.temporal.names, {"observation_timespans"})
            self.assertEqual(dataId.timespan, dataId.records["visit"].timespan)
            self.assertEqual(dataId.timespan, dataId.visit.timespan)
        # Also test the case for non-temporal DataIds.
        for dataId in self.randomDataIds(n=4).subset(
            DimensionGraph(self.allDataIds.universe, names=["patch"])
        ):
            self.assertIsNone(dataId.timespan)

    def testIterableStatusFlags(self):
        """Test that DataCoordinateSet and DataCoordinateSequence compute
        their hasFull and hasRecords flags correctly from their elements.
        """
        dataIds = self.randomDataIds(n=10)
        split = self.splitByStateFlags(dataIds)
        for cls in (DataCoordinateSet, DataCoordinateSequence):
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasFull())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=True).hasRecords())
            self.assertTrue(cls(split.expanded, graph=dataIds.graph, check=False).hasRecords())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=True).hasFull())
            self.assertTrue(cls(split.complete, graph=dataIds.graph, check=False).hasFull())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.complete, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.complete, graph=dataIds.graph, hasRecords=True, check=True)
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=True).hasFull(), not dataIds.graph.implied
            )
            self.assertEqual(
                cls(split.minimal, graph=dataIds.graph, check=False).hasFull(), not dataIds.graph.implied
            )
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=True).hasRecords())
            self.assertFalse(cls(split.minimal, graph=dataIds.graph, check=False).hasRecords())
            with self.assertRaises(ValueError):
                cls(split.minimal, graph=dataIds.graph, hasRecords=True, check=True)
            if dataIds.graph.implied:
                with self.assertRaises(ValueError):
                    cls(split.minimal, graph=dataIds.graph, hasFull=True, check=True)

    def testSetOperations(self):
        """Test for self-consistency across DataCoordinateSet's operations."""
        c = self.randomDataIds(n=10).toSet()
        a = self.randomDataIds(n=20).toSet() | c
        b = self.randomDataIds(n=20).toSet() | c
        # Make sure we don't have a particularly unlucky random seed, since
        # that would make a lot of this test uninteresting.
        self.assertNotEqual(a, b)
        self.assertGreater(len(a), 0)
        self.assertGreater(len(b), 0)
        # The rest of the tests should not depend on the random seed.
        self.assertEqual(a, a)
        self.assertNotEqual(a, a.toSequence())
        self.assertEqual(a, a.toSequence().toSet())
        self.assertEqual(b, b)
        self.assertNotEqual(b, b.toSequence())
        self.assertEqual(b, b.toSequence().toSet())
        self.assertEqual(a & b, a.intersection(b))
        self.assertLessEqual(a & b, a)
        self.assertLessEqual(a & b, b)
        self.assertEqual(a | b, a.union(b))
        self.assertGreaterEqual(a | b, a)
        self.assertGreaterEqual(a | b, b)
        self.assertEqual(a - b, a.difference(b))
        self.assertLessEqual(a - b, a)
        self.assertLessEqual(b - a, b)
        self.assertEqual(a ^ b, a.symmetric_difference(b))
        self.assertGreaterEqual(a ^ b, (a | b) - (a & b))


if __name__ == "__main__":
    unittest.main()