Coverage for tests/test_pipeline_graph.py: 15%

511 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-13 09:52 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests of things related to the PipelineGraph class.""" 

29 

30import copy 

31import io 

32import logging 

33import unittest 

34from typing import Any 

35 

36import lsst.pipe.base.automatic_connection_constants as acc 

37import lsst.utils.tests 

38from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory 

39from lsst.daf.butler.registry import MissingDatasetTypeError 

40from lsst.pipe.base.pipeline_graph import ( 

41 ConnectionTypeConsistencyError, 

42 DuplicateOutputError, 

43 Edge, 

44 EdgesChangedError, 

45 IncompatibleDatasetTypeError, 

46 NodeKey, 

47 NodeType, 

48 PipelineGraph, 

49 PipelineGraphError, 

50 TaskImportMode, 

51 UnresolvedGraphError, 

52) 

53from lsst.pipe.base.tests.mocks import ( 

54 DynamicConnectionConfig, 

55 DynamicTestPipelineTask, 

56 DynamicTestPipelineTaskConfig, 

57 get_mock_name, 

58) 

59 

60_LOG = logging.getLogger(__name__) 

61 

62 

class MockRegistry:
    """A minimal test double for `lsst.daf.butler.Registry` that supports
    only dataset-type lookup.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        # Public universe attribute mirrors the real Registry interface;
        # the dataset-type mapping is held privately.
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        """Return the registered dataset type with the given name.

        Raises `MissingDatasetTypeError` (with no chained context) when the
        name is unknown, matching the real Registry's behavior.
        """
        dataset_type = self._dataset_types.get(name)
        if dataset_type is None:
            raise MissingDatasetTypeError(name)
        return dataset_type

77 

78 

79class PipelineGraphTestCase(unittest.TestCase): 

80 """Tests for the `PipelineGraph` class. 

81 

82 Tests for `PipelineGraph.resolve` are mostly in 

83 `PipelineGraphResolveTestCase` later in this file. 

84 """ 

85 

    def setUp(self) -> None:
        """Build the small two-task pipeline graph shared by every test."""
        # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
        # 'input', 'intermediate', and 'output'. There are no dimensions on
        # any of those. We add tasks in reverse order to better test sorting.
        # There is one labeled task subset, 'only_b', with just 'b' in it.
        # We copy the configs so the originals (the instance attributes) can
        # be modified and reused after the ones passed in to the graph are
        # frozen.
        self.description = "A pipeline for PipelineGraph unit tests."
        self.graph = PipelineGraph()
        self.graph.description = self.description
        # Task 'b': init-input 'schema', runtime input 'intermediate_1',
        # runtime output 'output_1'.
        self.b_config = DynamicTestPipelineTaskConfig()
        self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
        self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
        # Task 'a': init-output 'schema', runtime input 'input_1', runtime
        # output 'intermediate_1' (consumed by 'b').
        self.a_config = DynamicTestPipelineTaskConfig()
        self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
        self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
        self.graph.add_task_subset("only_b", ["b"])
        self.subset_description = "A subset with only task B in it."
        self.graph.task_subsets["only_b"].description = self.subset_description
        self.dimensions = DimensionUniverse()
        # Show full diffs for the large expected-vs-actual comparisons below.
        self.maxDiff = None

112 

113 def test_unresolved_accessors(self) -> None: 

114 """Test attribute accessors, iteration, and simple methods on a graph 

115 that has not had `PipelineGraph.resolve` called on it. 

116 """ 

117 self.check_base_accessors(self.graph) 

118 self.assertEqual( 

119 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)" 

120 ) 

121 

122 def test_sorting(self) -> None: 

123 """Test sort methods on PipelineGraph.""" 

124 self.assertFalse(self.graph.has_been_sorted) 

125 self.assertFalse(self.graph.is_sorted) 

126 self.graph.sort() 

127 self.check_sorted(self.graph) 

128 

129 def test_unresolved_xgraph_export(self) -> None: 

130 """Test exporting an unresolved PipelineGraph to networkx in various 

131 ways. 

132 """ 

133 self.check_make_xgraph(self.graph, resolved=False) 

134 self.check_make_bipartite_xgraph(self.graph, resolved=False) 

135 self.check_make_task_xgraph(self.graph, resolved=False) 

136 self.check_make_dataset_type_xgraph(self.graph, resolved=False) 

137 

138 def test_unresolved_stream_io(self) -> None: 

139 """Test round-tripping an unresolved PipelineGraph through in-memory 

140 serialization. 

141 """ 

142 stream = io.BytesIO() 

143 self.graph._write_stream(stream) 

144 stream.seek(0) 

145 roundtripped = PipelineGraph._read_stream(stream) 

146 self.check_make_xgraph(roundtripped, resolved=False) 

147 

148 def test_unresolved_file_io(self) -> None: 

149 """Test round-tripping an unresolved PipelineGraph through file 

150 serialization. 

151 """ 

152 with lsst.utils.tests.getTempFilePath(".json.gz") as filename: 

153 self.graph._write_uri(filename) 

154 roundtripped = PipelineGraph._read_uri(filename) 

155 self.check_make_xgraph(roundtripped, resolved=False) 

156 

157 def test_unresolved_deferred_import_io(self) -> None: 

158 """Test round-tripping an unresolved PipelineGraph through 

159 serialization, without immediately importing tasks on read. 

160 """ 

161 stream = io.BytesIO() 

162 self.graph._write_stream(stream) 

163 stream.seek(0) 

164 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT) 

165 self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False) 

166 # Check that we can still resolve the graph without importing tasks. 

167 roundtripped.resolve(MockRegistry(self.dimensions, {})) 

168 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) 

169 roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES) 

170 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) 

171 

172 def test_resolved_accessors(self) -> None: 

173 """Test attribute accessors, iteration, and simple methods on a graph 

174 that has had `PipelineGraph.resolve` called on it. 

175 

176 This includes the accessors available on unresolved graphs as well as 

177 new ones, and we expect the resolved graph to be sorted as well. 

178 """ 

179 self.graph.resolve(MockRegistry(self.dimensions, {})) 

180 self.check_base_accessors(self.graph) 

181 self.check_sorted(self.graph) 

182 self.assertEqual( 

183 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})" 

184 ) 

185 self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty) 

186 self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})") 

187 self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1")) 

188 self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty) 

189 self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict") 

190 self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict") 

191 

192 def test_resolved_xgraph_export(self) -> None: 

193 """Test exporting a resolved PipelineGraph to networkx in various 

194 ways. 

195 """ 

196 self.graph.resolve(MockRegistry(self.dimensions, {})) 

197 self.check_make_xgraph(self.graph, resolved=True) 

198 self.check_make_bipartite_xgraph(self.graph, resolved=True) 

199 self.check_make_task_xgraph(self.graph, resolved=True) 

200 self.check_make_dataset_type_xgraph(self.graph, resolved=True) 

201 

202 def test_resolved_stream_io(self) -> None: 

203 """Test round-tripping a resolved PipelineGraph through in-memory 

204 serialization. 

205 """ 

206 self.graph.resolve(MockRegistry(self.dimensions, {})) 

207 stream = io.BytesIO() 

208 self.graph._write_stream(stream) 

209 stream.seek(0) 

210 roundtripped = PipelineGraph._read_stream(stream) 

211 self.check_make_xgraph(roundtripped, resolved=True) 

212 

213 def test_resolved_file_io(self) -> None: 

214 """Test round-tripping a resolved PipelineGraph through file 

215 serialization. 

216 """ 

217 self.graph.resolve(MockRegistry(self.dimensions, {})) 

218 with lsst.utils.tests.getTempFilePath(".json.gz") as filename: 

219 self.graph._write_uri(filename) 

220 roundtripped = PipelineGraph._read_uri(filename) 

221 self.check_make_xgraph(roundtripped, resolved=True) 

222 

223 def test_resolved_deferred_import_io(self) -> None: 

224 """Test round-tripping a resolved PipelineGraph through serialization, 

225 without immediately importing tasks on read. 

226 """ 

227 self.graph.resolve(MockRegistry(self.dimensions, {})) 

228 stream = io.BytesIO() 

229 self.graph._write_stream(stream) 

230 stream.seek(0) 

231 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT) 

232 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) 

233 roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES) 

234 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) 

235 

236 def test_unresolved_copies(self) -> None: 

237 """Test making copies of an unresolved PipelineGraph.""" 

238 copy1 = self.graph.copy() 

239 self.assertIsNot(copy1, self.graph) 

240 self.check_make_xgraph(copy1, resolved=False) 

241 copy2 = copy.copy(self.graph) 

242 self.assertIsNot(copy2, self.graph) 

243 self.check_make_xgraph(copy2, resolved=False) 

244 copy3 = copy.deepcopy(self.graph) 

245 self.assertIsNot(copy3, self.graph) 

246 self.check_make_xgraph(copy3, resolved=False) 

247 

248 def test_resolved_copies(self) -> None: 

249 """Test making copies of a resolved PipelineGraph.""" 

250 self.graph.resolve(MockRegistry(self.dimensions, {})) 

251 copy1 = self.graph.copy() 

252 self.assertIsNot(copy1, self.graph) 

253 self.check_make_xgraph(copy1, resolved=True) 

254 copy2 = copy.copy(self.graph) 

255 self.assertIsNot(copy2, self.graph) 

256 self.check_make_xgraph(copy2, resolved=True) 

257 copy3 = copy.deepcopy(self.graph) 

258 self.assertIsNot(copy3, self.graph) 

259 self.check_make_xgraph(copy3, resolved=True) 

260 

261 def check_base_accessors(self, graph: PipelineGraph) -> None: 

262 """Run parameterized tests that check attribute access, iteration, and 

263 simple methods. 

264 

265 The given graph must be unchanged from the one defined in `setUp`, 

266 other than sorting. 

267 """ 

268 self.assertEqual(graph.description, self.description) 

269 self.assertEqual(graph.tasks.keys(), {"a", "b"}) 

270 self.assertEqual( 

271 graph.dataset_types.keys(), 

272 { 

273 "schema", 

274 "input_1", 

275 "intermediate_1", 

276 "output_1", 

277 "a_config", 

278 "a_log", 

279 "a_metadata", 

280 "b_config", 

281 "b_log", 

282 "b_metadata", 

283 }, 

284 ) 

285 self.assertEqual(graph.task_subsets.keys(), {"only_b"}) 

286 self.assertEqual( 

287 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)}, 

288 { 

289 ( 

290 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

291 NodeKey(NodeType.TASK, "a"), 

292 "input_1 -> a (input1)", 

293 ), 

294 ( 

295 NodeKey(NodeType.TASK, "a"), 

296 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

297 "a -> intermediate_1 (output1)", 

298 ), 

299 ( 

300 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

301 NodeKey(NodeType.TASK, "b"), 

302 "intermediate_1 -> b (input1)", 

303 ), 

304 ( 

305 NodeKey(NodeType.TASK, "b"), 

306 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

307 "b -> output_1 (output1)", 

308 ), 

309 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"), 

310 ( 

311 NodeKey(NodeType.TASK, "a"), 

312 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

313 "a -> a_metadata (_metadata)", 

314 ), 

315 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"), 

316 ( 

317 NodeKey(NodeType.TASK, "b"), 

318 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

319 "b -> b_metadata (_metadata)", 

320 ), 

321 }, 

322 ) 

323 self.assertEqual( 

324 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)}, 

325 { 

326 ( 

327 NodeKey(NodeType.TASK_INIT, "a"), 

328 NodeKey(NodeType.DATASET_TYPE, "schema"), 

329 "a -> schema (out_schema)", 

330 ), 

331 ( 

332 NodeKey(NodeType.DATASET_TYPE, "schema"), 

333 NodeKey(NodeType.TASK_INIT, "b"), 

334 "schema -> b (in_schema)", 

335 ), 

336 ( 

337 NodeKey(NodeType.TASK_INIT, "a"), 

338 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

339 "a -> a_config (_config)", 

340 ), 

341 ( 

342 NodeKey(NodeType.TASK_INIT, "b"), 

343 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

344 "b -> b_config (_config)", 

345 ), 

346 }, 

347 ) 

348 self.assertEqual( 

349 {(node_type, name) for node_type, name, _ in graph.iter_nodes()}, 

350 { 

351 NodeKey(NodeType.TASK, "a"), 

352 NodeKey(NodeType.TASK, "b"), 

353 NodeKey(NodeType.TASK_INIT, "a"), 

354 NodeKey(NodeType.TASK_INIT, "b"), 

355 NodeKey(NodeType.DATASET_TYPE, "schema"), 

356 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

357 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

358 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

359 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

360 NodeKey(NodeType.DATASET_TYPE, "a_log"), 

361 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

362 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

363 NodeKey(NodeType.DATASET_TYPE, "b_log"), 

364 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

365 }, 

366 ) 

367 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"}) 

368 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"}) 

369 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"}) 

370 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set()) 

371 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"}) 

372 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"}) 

373 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set()) 

374 

375 self.assertIsNone(graph.producing_edge_of("input_1")) 

376 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a") 

377 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b") 

378 self.assertIsNone(graph.producer_of("input_1")) 

379 self.assertEqual(graph.producer_of("intermediate_1").label, "a") 

380 self.assertEqual(graph.producer_of("output_1").label, "b") 

381 

382 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"}) 

383 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"}) 

384 self.assertEqual(graph.inputs_of("a", init=True).keys(), set()) 

385 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"}) 

386 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"}) 

387 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"}) 

388 self.assertEqual( 

389 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"} 

390 ) 

391 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"}) 

392 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"}) 

393 self.assertEqual( 

394 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"} 

395 ) 

396 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"}) 

397 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set()) 

398 

399 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) 

400 self.assertEqual( 

401 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}" 

402 ) 

403 

404 def check_sorted(self, graph: PipelineGraph) -> None: 

405 """Run a battery of tests on a PipelineGraph that must be 

406 deterministically sorted. 

407 

408 The given graph must be unchanged from the one defined in `setUp`, 

409 other than sorting. 

410 """ 

411 self.assertTrue(graph.has_been_sorted) 

412 self.assertTrue(graph.is_sorted) 

413 self.assertEqual( 

414 [(node_type, name) for node_type, name, _ in graph.iter_nodes()], 

415 [ 

416 # We only advertise that the order is topological and 

417 # deterministic, so this test is slightly over-specified; there 

418 # are other orders that are consistent with our guarantees. 

419 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

420 NodeKey(NodeType.TASK_INIT, "a"), 

421 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

422 NodeKey(NodeType.DATASET_TYPE, "schema"), 

423 NodeKey(NodeType.TASK_INIT, "b"), 

424 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

425 NodeKey(NodeType.TASK, "a"), 

426 NodeKey(NodeType.DATASET_TYPE, "a_log"), 

427 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

428 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

429 NodeKey(NodeType.TASK, "b"), 

430 NodeKey(NodeType.DATASET_TYPE, "b_log"), 

431 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

432 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

433 ], 

434 ) 

435 # Most users should only care that the tasks and dataset types are 

436 # topologically sorted. 

437 self.assertEqual(list(graph.tasks), ["a", "b"]) 

438 self.assertEqual( 

439 list(graph.dataset_types), 

440 [ 

441 "input_1", 

442 "a_config", 

443 "schema", 

444 "b_config", 

445 "a_log", 

446 "a_metadata", 

447 "intermediate_1", 

448 "b_log", 

449 "b_metadata", 

450 "output_1", 

451 ], 

452 ) 

453 # __str__ and __repr__ of course work on unsorted mapping views, too, 

454 # but the order of elements is then nondeterministic and hard to test. 

455 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})") 

456 self.assertEqual( 

457 repr(self.graph.dataset_types), 

458 ( 

459 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, " 

460 "intermediate_1, b_log, b_metadata, output_1})" 

461 ), 

462 ) 

463 

    def check_make_xgraph(
        self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
    ) -> None:
        """Check that the given graph exports as expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``) or round-tripped
        through serialization without tasks being imported (if
        ``imported_and_configured=False``).
        """
        xgraph = graph.make_xgraph()
        # The full export carries both runtime and init edges, plus one
        # special edge per task tying its init node to its runtime node.
        expected_edges = (
            {edge.key for edge in graph.iter_edges()}
            | {edge.key for edge in graph.iter_edges(init=True)}
            | {
                (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
                (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
            }
        )
        test_edges = set(xgraph.edges)
        self.assertEqual(test_edges, expected_edges)
        # Expected node attribute dicts, built by the get_expected_* helpers
        # so that the resolved/unresolved and imported/deferred variants are
        # covered uniformly.
        expected_nodes = {
            NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                "schema", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                "input_1", resolved, is_initial_query_constraint=True
            ),
            NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                "intermediate_1", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                "output_1", resolved, is_initial_query_constraint=False
            ),
        }
        test_nodes = dict(xgraph.nodes.items())
        self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
        # Compare node-by-node (passing the key as the assertion message) so
        # a failure pinpoints the offending node instead of dumping one huge
        # dict-of-dicts diff.
        for key, expected_node in expected_nodes.items():
            test_node = test_nodes[key]
            self.assertEqual(expected_node, test_node, key)

522 

    def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's init-only or runtime subset exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime subgraph: only task and runtime dataset-type nodes, with
        # runtime edges.
        run_xgraph = graph.make_bipartite_xgraph()
        self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init subgraph: task-init nodes, init-input/-output dataset types,
        # and init edges only.
        init_xgraph = graph.make_bipartite_xgraph(
            init=True,
        )
        self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )

568 

    def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's task-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: dataset types are collapsed away, leaving a
        # single a -> b dependency edge.
        run_xgraph = graph.make_task_xgraph()
        self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
            },
        )
        # Init projection: the same a -> b dependency, but between the
        # task-init nodes (via the 'schema' init dataset type).
        init_xgraph = graph.make_task_xgraph(
            init=True,
        )
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
            },
        )

599 

    def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's dataset-type-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: tasks are collapsed away, so each task's inputs
        # connect directly to all of its outputs (including log/metadata).
        run_xgraph = graph.make_dataset_type_xgraph()
        self.assertEqual(
            set(run_xgraph.edges),
            {
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "output_1"),
                ),
                (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
                ),
            },
        )
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init projection: 'schema' (a's init-output, b's init-input) feeds
        # b's config dataset type; a_config has no incoming edges.
        init_xgraph = graph.make_dataset_type_xgraph(init=True)
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )

658 

659 def get_expected_task_node( 

660 self, label: str, resolved: bool, imported_and_configured: bool = True 

661 ) -> dict[str, Any]: 

662 """Construct a networkx-export task node for comparison.""" 

663 result = self.get_expected_task_init_node( 

664 label, resolved, imported_and_configured=imported_and_configured 

665 ) 

666 if resolved: 

667 result["dimensions"] = self.dimensions.empty 

668 result["raw_dimensions"] = frozenset() 

669 return result 

670 

671 def get_expected_task_init_node( 

672 self, label: str, resolved: bool, imported_and_configured: bool = True 

673 ) -> dict[str, Any]: 

674 """Construct a networkx-export task init for comparison.""" 

675 result = { 

676 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask", 

677 "bipartite": 1, 

678 } 

679 if imported_and_configured: 

680 result["task_class"] = DynamicTestPipelineTask 

681 result["config"] = getattr(self, f"{label}_config") 

682 return result 

683 

684 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]: 

685 """Construct a networkx-export init-output config dataset type node for 

686 comparison. 

687 """ 

688 if not resolved: 

689 return {"bipartite": 0} 

690 else: 

691 return { 

692 "dataset_type": DatasetType( 

693 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), 

694 self.dimensions.empty, 

695 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

696 ), 

697 "is_initial_query_constraint": False, 

698 "is_prerequisite": False, 

699 "dimensions": self.dimensions.empty, 

700 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

701 "bipartite": 0, 

702 } 

703 

704 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]: 

705 """Construct a networkx-export output log dataset type node for 

706 comparison. 

707 """ 

708 if not resolved: 

709 return {"bipartite": 0} 

710 else: 

711 return { 

712 "dataset_type": DatasetType( 

713 acc.LOG_OUTPUT_TEMPLATE.format(label=label), 

714 self.dimensions.empty, 

715 acc.LOG_OUTPUT_STORAGE_CLASS, 

716 ), 

717 "is_initial_query_constraint": False, 

718 "is_prerequisite": False, 

719 "dimensions": self.dimensions.empty, 

720 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS, 

721 "bipartite": 0, 

722 } 

723 

724 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]: 

725 """Construct a networkx-export output metadata dataset type node for 

726 comparison. 

727 """ 

728 if not resolved: 

729 return {"bipartite": 0} 

730 else: 

731 return { 

732 "dataset_type": DatasetType( 

733 acc.METADATA_OUTPUT_TEMPLATE.format(label=label), 

734 self.dimensions.empty, 

735 acc.METADATA_OUTPUT_STORAGE_CLASS, 

736 ), 

737 "is_initial_query_constraint": False, 

738 "is_prerequisite": False, 

739 "dimensions": self.dimensions.empty, 

740 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS, 

741 "bipartite": 0, 

742 } 

743 

744 def get_expected_connection_node( 

745 self, name: str, resolved: bool, *, is_initial_query_constraint: bool 

746 ) -> dict[str, Any]: 

747 """Construct a networkx-export dataset type node for comparison.""" 

748 if not resolved: 

749 return {"bipartite": 0} 

750 else: 

751 return { 

752 "dataset_type": DatasetType( 

753 name, 

754 self.dimensions.empty, 

755 get_mock_name("StructuredDataDict"), 

756 ), 

757 "is_initial_query_constraint": is_initial_query_constraint, 

758 "is_prerequisite": False, 

759 "dimensions": self.dimensions.empty, 

760 "storage_class_name": get_mock_name("StructuredDataDict"), 

761 "bipartite": 0, 

762 } 

763 

    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with DimensionUniverse, all tasks added to
        it should have resolved dimensions, but not (yet) resolved dataset
        types. We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        # The data ID carries the universe, so the graph should pick both up.
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        # Task dimensions resolve immediately because the universe is known.
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions (``self.graph`` has no universe).
        self.graph.add_task_nodes([graph.tasks["b1"]])
        # The transferred node is a copy, not the same object.
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)

790 

    def test_group_by_dimensions(self) -> None:
        """Test PipelineGraph.group_by_dimensions."""
        # Grouping requires a resolved graph.
        with self.assertRaises(UnresolvedGraphError):
            self.graph.group_by_dimensions()
        # Give the two tasks distinct dimensions ("visit" vs "htm7") so the
        # grouping below is nontrivial.
        self.a_config.dimensions = ["visit"]
        self.a_config.outputs["output1"].dimensions = ["visit"]
        self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
            dataset_type_name="prereq_1",
            multiple=True,
            dimensions=["htm7"],
            is_calibration=True,
        )
        self.b_config.dimensions = ["htm7"]
        self.b_config.inputs["input1"].dimensions = ["visit"]
        self.b_config.inputs["input1"].multiple = True
        self.b_config.outputs["output1"].dimensions = ["htm7"]
        self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        visit_dims = self.dimensions.extract(["visit"])
        htm7_dims = self.dimensions.extract(["htm7"])
        # Expected mapping: dimensions -> (tasks, dataset types) with those
        # dimensions.  Init-output (config), overall-input, and schema dataset
        # types are dimensionless; logs/metadata follow their task.
        expected = {
            self.dimensions.empty: (
                {},
                {
                    "schema": self.graph.dataset_types["schema"],
                    "input_1": self.graph.dataset_types["input_1"],
                    "a_config": self.graph.dataset_types["a_config"],
                    "b_config": self.graph.dataset_types["b_config"],
                },
            ),
            visit_dims: (
                {"a": self.graph.tasks["a"]},
                {
                    "a_log": self.graph.dataset_types["a_log"],
                    "a_metadata": self.graph.dataset_types["a_metadata"],
                    "intermediate_1": self.graph.dataset_types["intermediate_1"],
                },
            ),
            htm7_dims: (
                {"b": self.graph.tasks["b"]},
                {
                    "b_log": self.graph.dataset_types["b_log"],
                    "b_metadata": self.graph.dataset_types["b_metadata"],
                    "output_1": self.graph.dataset_types["output_1"],
                },
            ),
        }
        self.assertEqual(self.graph.group_by_dimensions(), expected)
        # Prerequisite inputs are excluded by default but included on request.
        expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
        self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)

841 

842 def test_add_and_remove(self) -> None: 

843 """Tests for adding and removing tasks and task subsets from a 

844 PipelineGraph. 

845 """ 

846 # Can't remove a task while it's still in a subset. 

847 with self.assertRaises(PipelineGraphError): 

848 self.graph.remove_tasks(["b"], drop_from_subsets=False) 

849 # ...unless you remove the subset. 

850 self.graph.remove_task_subset("only_b") 

851 self.assertFalse(self.graph.task_subsets) 

852 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False) 

853 self.assertFalse(referencing_subsets) 

854 self.assertEqual(self.graph.tasks.keys(), {"a"}) 

855 # Add that task back in. 

856 self.graph.add_task_nodes([b]) 

857 self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) 

858 # Add the subset back in. 

859 self.graph.add_task_subset("only_b", {"b"}) 

860 self.assertEqual(self.graph.task_subsets.keys(), {"only_b"}) 

861 # Resolve the graph's dataset types and task dimensions. 

862 self.graph.resolve(MockRegistry(self.dimensions, {})) 

863 self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) 

864 self.assertTrue(self.graph.dataset_types.is_resolved("output_1")) 

865 self.assertTrue(self.graph.dataset_types.is_resolved("schema")) 

866 self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1")) 

867 # Remove the task while removing it from the subset automatically. This 

868 # should also unresolve (only) the referenced dataset types and drop 

869 # any datasets no longer attached to any task. 

870 self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) 

871 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True) 

872 self.assertEqual(referencing_subsets, {"only_b"}) 

873 self.assertEqual(self.graph.tasks.keys(), {"a"}) 

874 self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) 

875 self.assertNotIn("output1", self.graph.dataset_types) 

876 self.assertFalse(self.graph.dataset_types.is_resolved("schema")) 

877 self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1")) 

878 

    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Changing an output connection's storage class is an edge change.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        # Asking for a check should detect the storage-class edge change.
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        # Make a change that does affect edges; this will unresolve most
        # dataset types.
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )

905 

906 

def _have_example_storage_classes() -> bool:
    """Check whether some storage classes work as expected.

    Given that these have registered converters, it shouldn't actually be
    necessary to be able to import those types in order to determine that
    they're convertible, but the storage class machinery is implemented such
    that types that can't be imported can't be converted, and while that's
    inconvenient here it's totally fine in non-testing scenarios where you only
    care about a storage class if you can actually use it.
    """
    factory = StorageClassFactory()

    def _convertible(name_a: str, name_b: str) -> bool:
        # True if storage class ``name_a`` can be converted from ``name_b``.
        return factory.getStorageClass(name_a).can_convert(factory.getStorageClass(name_b))

    # ``all`` short-circuits, matching the original chained ``and``s.
    return all(
        _convertible(name_a, name_b)
        for name_a, name_b in (
            ("ArrowTable", "ArrowAstropy"),
            ("ArrowAstropy", "ArrowTable"),
            ("ArrowTable", "DataFrame"),
            ("DataFrame", "ArrowTable"),
        )
    )

924 

925 

class PipelineGraphResolveTestCase(unittest.TestCase):
    """More extensive tests for PipelineGraph.resolve and its private helper
    methods.

    These are in a separate TestCase because they utilize a different `setUp`
    from the rest of the `PipelineGraph` tests.
    """

    def setUp(self) -> None:
        # Two dynamically-configured tasks; each test customizes their
        # connections before building the graph via `make_graph`.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.b_config = DynamicTestPipelineTaskConfig()
        self.dimensions = DimensionUniverse()
        self.maxDiff = None

    def make_graph(self) -> PipelineGraph:
        """Build an unresolved two-task graph from the current configs."""
        graph = PipelineGraph()
        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
        return graph

    def test_prerequisite_inconsistency(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite and another does not.

        This test will hopefully someday go away (along with
        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
        algorithm becomes more flexible.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_inconsistency_reversed(self) -> None:
        """Same as `test_prerequisite_inconsistency`, with the order the edges
        are added to the graph reversed.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_output(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite but another defines it as an output.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_missing(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the dataset type is not registered.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix"}
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_inconsistent(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the rest of the dimensions are
        inconsistent with the registered dataset type.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix", "visit"}
        )
        graph = self.make_graph()
        # Registered dataset type is missing "visit".
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )
        # Registered dataset type has extra dimensions beyond the connection's.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )

    def test_duplicate_outputs(self) -> None:
        """Test that we raise an exception when a dataset type node would have
        two write edges.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(DuplicateOutputError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_component_of_unregistered_parent(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is not registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_undefined_component(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but its storage class does not have that
        component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_bad_component_storage_class(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but does not have that component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="StructuredDataDict"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
                )
            )

    def test_input_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_output_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an output connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_input_storage_class_incompatible_with_output(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the storage class of the output connection.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_ambiguous_storage_class(self) -> None:
        """Test that we raise an exception when two input connections define
        the same dataset with different storage classes (even compatible ones)
        and there is no output connection or registry definition to take
        precedence.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where input edges have
        different but compatible storage classes and the dataset type is
        already registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        # The registry definition (DataFrame) takes precedence over both edges.
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        # Each edge adapts the graph-level dataset type to its own storage
        # class.
        self.assertEqual(
            a_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        self.assertEqual(
            b_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        # generalize_ref is the inverse of adapt_dataset_ref.
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_output_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where an output edge
        has a different but compatible storage class from the dataset type
        already registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        self.assertEqual(
            a_o.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_output(self) -> None:
        """Test successful resolution of a dataset type where an input edge has
        a different but compatible storage class from the output edge, and
        the dataset type is not registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        graph.resolve(MockRegistry(self.dimensions, {}))
        # With no registry definition, the output edge's storage class wins.
        self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
        self.assertEqual(
            a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type,
        )
        self.assertEqual(
            b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_input(self) -> None:
        """Test successful resolution of a component dataset type due to
        another input referencing the parent dataset type.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        # The graph-level node is always the parent dataset type.
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_output(self) -> None:
        """Test successful resolution of a component dataset type due to
        an output connection referencing the parent dataset type.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_registry(self) -> None:
        """Test successful resolution of a component dataset type due to
        the parent dataset type already being registered.
        """
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

1330 

1331 

if __name__ == "__main__":
    # Initialize LSST test utilities before handing off to unittest's CLI.
    lsst.utils.tests.init()
    unittest.main()