Coverage for tests/test_pipeline_graph.py: 16%

537 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 03:31 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests of things related to the GraphBuilder class.""" 

29 

30import copy 

31import io 

32import logging 

33import pickle 

34import textwrap 

35import unittest 

36from typing import Any 

37 

38import lsst.pipe.base.automatic_connection_constants as acc 

39import lsst.utils.tests 

40from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory 

41from lsst.daf.butler.registry import MissingDatasetTypeError 

42from lsst.pipe.base.pipeline_graph import ( 

43 ConnectionTypeConsistencyError, 

44 DuplicateOutputError, 

45 Edge, 

46 EdgesChangedError, 

47 IncompatibleDatasetTypeError, 

48 NodeKey, 

49 NodeType, 

50 PipelineGraph, 

51 PipelineGraphError, 

52 TaskImportMode, 

53 UnresolvedGraphError, 

54 visualization, 

55) 

56from lsst.pipe.base.tests.mocks import ( 

57 DynamicConnectionConfig, 

58 DynamicTestPipelineTask, 

59 DynamicTestPipelineTaskConfig, 

60 get_mock_name, 

61) 

62 

63_LOG = logging.getLogger(__name__) 

64 

65 

class MockRegistry:
    """A test-utility stand-in for lsst.daf.butler.Registry that just knows
    how to get dataset types.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        # The mapping is stored by reference, so tests may mutate it after
        # construction to change what this "registry" knows about.
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        # Mimic the real Registry interface: return the registered dataset
        # type, or raise MissingDatasetTypeError if it is not known.
        try:
            return self._dataset_types[name]
        except KeyError:
            # `from None` suppresses the KeyError chain so callers see only
            # the Registry-style exception, as the real implementation does.
            raise MissingDatasetTypeError(name) from None

80 

81 

82class PipelineGraphTestCase(unittest.TestCase): 

83 """Tests for the `PipelineGraph` class. 

84 

85 Tests for `PipelineGraph.resolve` are mostly in 

86 `PipelineGraphResolveTestCase` later in this file. 

87 """ 

88 

    def setUp(self) -> None:
        # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
        # 'input', 'intermediate', and 'output'. There are no dimensions on
        # any of those. We add tasks in reverse order to better test sorting.
        # There is one labeled task subset, 'only_b', with just 'b' in it.
        # We copy the configs so the originals (the instance attributes) can
        # be modified and reused after the ones passed in to the graph are
        # frozen.
        self.description = "A pipeline for PipelineGraph unit tests."
        self.graph = PipelineGraph()
        self.graph.description = self.description
        # Task 'b' consumes the init-input 'schema' and the run-input
        # 'intermediate_1' that task 'a' (added second) produces.
        self.b_config = DynamicTestPipelineTaskConfig()
        self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
        self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
        # Task 'a' is the producer side of the 'schema' and 'intermediate_1'
        # connections consumed by 'b'.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
        self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
        self.graph.add_task_subset("only_b", ["b"])
        self.subset_description = "A subset with only task B in it."
        self.graph.task_subsets["only_b"].description = self.subset_description
        self.dimensions = DimensionUniverse()
        # Show full diffs for the large container comparisons in this file.
        self.maxDiff = None

115 

116 def test_unresolved_accessors(self) -> None: 

117 """Test attribute accessors, iteration, and simple methods on a graph 

118 that has not had `PipelineGraph.resolve` called on it. 

119 """ 

120 self.check_base_accessors(self.graph) 

121 self.assertEqual( 

122 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)" 

123 ) 

124 

125 def test_sorting(self) -> None: 

126 """Test sort methods on PipelineGraph.""" 

127 self.assertFalse(self.graph.has_been_sorted) 

128 self.assertFalse(self.graph.is_sorted) 

129 self.graph.sort() 

130 self.check_sorted(self.graph) 

131 

132 def test_unresolved_xgraph_export(self) -> None: 

133 """Test exporting an unresolved PipelineGraph to networkx in various 

134 ways. 

135 """ 

136 self.check_make_xgraph(self.graph, resolved=False) 

137 self.check_make_bipartite_xgraph(self.graph, resolved=False) 

138 self.check_make_task_xgraph(self.graph, resolved=False) 

139 self.check_make_dataset_type_xgraph(self.graph, resolved=False) 

140 

141 def test_unresolved_stream_io(self) -> None: 

142 """Test round-tripping an unresolved PipelineGraph through in-memory 

143 serialization. 

144 """ 

145 stream = io.BytesIO() 

146 self.graph._write_stream(stream) 

147 stream.seek(0) 

148 roundtripped = PipelineGraph._read_stream(stream) 

149 self.check_make_xgraph(roundtripped, resolved=False) 

150 

151 def test_unresolved_file_io(self) -> None: 

152 """Test round-tripping an unresolved PipelineGraph through file 

153 serialization. 

154 """ 

155 with lsst.utils.tests.getTempFilePath(".json.gz") as filename: 

156 self.graph._write_uri(filename) 

157 roundtripped = PipelineGraph._read_uri(filename) 

158 self.check_make_xgraph(roundtripped, resolved=False) 

159 

160 def test_unresolved_pickle(self) -> None: 

161 """Test that unresolved PipelineGraph objects can be pickled.""" 

162 self.check_make_xgraph(pickle.loads(pickle.dumps(self.graph)), resolved=False) 

163 

    def test_unresolved_deferred_import_io(self) -> None:
        """Test round-tripping an unresolved PipelineGraph through
        serialization, without immediately importing tasks on read.
        """
        stream = io.BytesIO()
        self.graph._write_stream(stream)
        stream.seek(0)
        # DO_NOT_IMPORT leaves task classes and configs unloaded on read.
        roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
        self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
        # A deferred-import graph must survive pickling without triggering
        # any imports.
        self.check_make_xgraph(
            pickle.loads(pickle.dumps(roundtripped)), resolved=False, imported_and_configured=False
        )
        # Check that we can still resolve the graph without importing tasks.
        roundtripped.resolve(MockRegistry(self.dimensions, {}))
        self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
        # Finally import/configure; serialized edges are trusted as-is here.
        roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
        self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)

181 

    def test_resolved_accessors(self) -> None:
        """Test attribute accessors, iteration, and simple methods on a graph
        that has had `PipelineGraph.resolve` called on it.

        This includes the accessors available on unresolved graphs as well as
        new ones, and we expect the resolved graph to be sorted as well.
        """
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.check_base_accessors(self.graph)
        self.check_sorted(self.graph)
        # After resolution the task repr gains a dimensions suffix ('{}' for
        # empty dimensions here).
        self.assertEqual(
            repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})"
        )
        self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty)
        # Dataset type nodes expose resolved storage class and dimensions.
        self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})")
        self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1"))
        self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
        self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
        self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")

201 

202 def test_resolved_xgraph_export(self) -> None: 

203 """Test exporting a resolved PipelineGraph to networkx in various 

204 ways. 

205 """ 

206 self.graph.resolve(MockRegistry(self.dimensions, {})) 

207 self.check_make_xgraph(self.graph, resolved=True) 

208 self.check_make_bipartite_xgraph(self.graph, resolved=True) 

209 self.check_make_task_xgraph(self.graph, resolved=True) 

210 self.check_make_dataset_type_xgraph(self.graph, resolved=True) 

211 

212 def test_resolved_stream_io(self) -> None: 

213 """Test round-tripping a resolved PipelineGraph through in-memory 

214 serialization. 

215 """ 

216 self.graph.resolve(MockRegistry(self.dimensions, {})) 

217 stream = io.BytesIO() 

218 self.graph._write_stream(stream) 

219 stream.seek(0) 

220 roundtripped = PipelineGraph._read_stream(stream) 

221 self.check_make_xgraph(roundtripped, resolved=True) 

222 

223 def test_resolved_file_io(self) -> None: 

224 """Test round-tripping a resolved PipelineGraph through file 

225 serialization. 

226 """ 

227 self.graph.resolve(MockRegistry(self.dimensions, {})) 

228 with lsst.utils.tests.getTempFilePath(".json.gz") as filename: 

229 self.graph._write_uri(filename) 

230 roundtripped = PipelineGraph._read_uri(filename) 

231 self.check_make_xgraph(roundtripped, resolved=True) 

232 

233 def test_resolved_pickle(self) -> None: 

234 """Test that resolved PipelineGraph objects can be pickled.""" 

235 self.graph.resolve(MockRegistry(self.dimensions, {})) 

236 self.check_make_xgraph(pickle.loads(pickle.dumps(self.graph)), resolved=True) 

237 

    def test_resolved_deferred_import_io(self) -> None:
        """Test round-tripping a resolved PipelineGraph through serialization,
        without immediately importing tasks on read.
        """
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        stream = io.BytesIO()
        self.graph._write_stream(stream)
        stream.seek(0)
        # DO_NOT_IMPORT leaves task classes and configs unloaded on read.
        roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
        self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
        # A deferred-import graph must survive pickling without triggering
        # any imports.
        self.check_make_xgraph(
            pickle.loads(pickle.dumps(roundtripped)), resolved=True, imported_and_configured=False
        )
        # Importing afterwards must verify that edges match the serialized
        # ones (REQUIRE_CONSISTENT_EDGES), unlike the unresolved-graph test.
        roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
        self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)

253 

254 def test_unresolved_copies(self) -> None: 

255 """Test making copies of an unresolved PipelineGraph.""" 

256 copy1 = self.graph.copy() 

257 self.assertIsNot(copy1, self.graph) 

258 self.check_make_xgraph(copy1, resolved=False) 

259 copy2 = copy.copy(self.graph) 

260 self.assertIsNot(copy2, self.graph) 

261 self.check_make_xgraph(copy2, resolved=False) 

262 copy3 = copy.deepcopy(self.graph) 

263 self.assertIsNot(copy3, self.graph) 

264 self.check_make_xgraph(copy3, resolved=False) 

265 

266 def test_resolved_copies(self) -> None: 

267 """Test making copies of a resolved PipelineGraph.""" 

268 self.graph.resolve(MockRegistry(self.dimensions, {})) 

269 copy1 = self.graph.copy() 

270 self.assertIsNot(copy1, self.graph) 

271 self.check_make_xgraph(copy1, resolved=True) 

272 copy2 = copy.copy(self.graph) 

273 self.assertIsNot(copy2, self.graph) 

274 self.check_make_xgraph(copy2, resolved=True) 

275 copy3 = copy.deepcopy(self.graph) 

276 self.assertIsNot(copy3, self.graph) 

277 self.check_make_xgraph(copy3, resolved=True) 

278 

279 def check_base_accessors(self, graph: PipelineGraph) -> None: 

280 """Run parameterized tests that check attribute access, iteration, and 

281 simple methods. 

282 

283 The given graph must be unchanged from the one defined in `setUp`, 

284 other than sorting. 

285 """ 

286 self.assertEqual(graph.description, self.description) 

287 self.assertEqual(graph.tasks.keys(), {"a", "b"}) 

288 self.assertEqual( 

289 graph.dataset_types.keys(), 

290 { 

291 "schema", 

292 "input_1", 

293 "intermediate_1", 

294 "output_1", 

295 "a_config", 

296 "a_log", 

297 "a_metadata", 

298 "b_config", 

299 "b_log", 

300 "b_metadata", 

301 }, 

302 ) 

303 self.assertEqual(graph.task_subsets.keys(), {"only_b"}) 

304 self.assertEqual( 

305 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)}, 

306 { 

307 ( 

308 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

309 NodeKey(NodeType.TASK, "a"), 

310 "input_1 -> a (input1)", 

311 ), 

312 ( 

313 NodeKey(NodeType.TASK, "a"), 

314 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

315 "a -> intermediate_1 (output1)", 

316 ), 

317 ( 

318 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

319 NodeKey(NodeType.TASK, "b"), 

320 "intermediate_1 -> b (input1)", 

321 ), 

322 ( 

323 NodeKey(NodeType.TASK, "b"), 

324 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

325 "b -> output_1 (output1)", 

326 ), 

327 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"), 

328 ( 

329 NodeKey(NodeType.TASK, "a"), 

330 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

331 "a -> a_metadata (_metadata)", 

332 ), 

333 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"), 

334 ( 

335 NodeKey(NodeType.TASK, "b"), 

336 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

337 "b -> b_metadata (_metadata)", 

338 ), 

339 }, 

340 ) 

341 self.assertEqual( 

342 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)}, 

343 { 

344 ( 

345 NodeKey(NodeType.TASK_INIT, "a"), 

346 NodeKey(NodeType.DATASET_TYPE, "schema"), 

347 "a -> schema (out_schema)", 

348 ), 

349 ( 

350 NodeKey(NodeType.DATASET_TYPE, "schema"), 

351 NodeKey(NodeType.TASK_INIT, "b"), 

352 "schema -> b (in_schema)", 

353 ), 

354 ( 

355 NodeKey(NodeType.TASK_INIT, "a"), 

356 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

357 "a -> a_config (_config)", 

358 ), 

359 ( 

360 NodeKey(NodeType.TASK_INIT, "b"), 

361 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

362 "b -> b_config (_config)", 

363 ), 

364 }, 

365 ) 

366 self.assertEqual( 

367 {(node_type, name) for node_type, name, _ in graph.iter_nodes()}, 

368 { 

369 NodeKey(NodeType.TASK, "a"), 

370 NodeKey(NodeType.TASK, "b"), 

371 NodeKey(NodeType.TASK_INIT, "a"), 

372 NodeKey(NodeType.TASK_INIT, "b"), 

373 NodeKey(NodeType.DATASET_TYPE, "schema"), 

374 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

375 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

376 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

377 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

378 NodeKey(NodeType.DATASET_TYPE, "a_log"), 

379 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

380 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

381 NodeKey(NodeType.DATASET_TYPE, "b_log"), 

382 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

383 }, 

384 ) 

385 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"}) 

386 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"}) 

387 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"}) 

388 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set()) 

389 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"}) 

390 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"}) 

391 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set()) 

392 

393 self.assertIsNone(graph.producing_edge_of("input_1")) 

394 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a") 

395 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b") 

396 self.assertIsNone(graph.producer_of("input_1")) 

397 self.assertEqual(graph.producer_of("intermediate_1").label, "a") 

398 self.assertEqual(graph.producer_of("output_1").label, "b") 

399 

400 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"}) 

401 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"}) 

402 self.assertEqual(graph.inputs_of("a", init=True).keys(), set()) 

403 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"}) 

404 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"}) 

405 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"}) 

406 self.assertEqual( 

407 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"} 

408 ) 

409 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"}) 

410 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"}) 

411 self.assertEqual( 

412 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"} 

413 ) 

414 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"}) 

415 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set()) 

416 

417 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) 

418 self.assertEqual( 

419 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}" 

420 ) 

421 

422 def check_sorted(self, graph: PipelineGraph) -> None: 

423 """Run a battery of tests on a PipelineGraph that must be 

424 deterministically sorted. 

425 

426 The given graph must be unchanged from the one defined in `setUp`, 

427 other than sorting. 

428 """ 

429 self.assertTrue(graph.has_been_sorted) 

430 self.assertTrue(graph.is_sorted) 

431 self.assertEqual( 

432 [(node_type, name) for node_type, name, _ in graph.iter_nodes()], 

433 [ 

434 # We only advertise that the order is topological and 

435 # deterministic, so this test is slightly over-specified; there 

436 # are other orders that are consistent with our guarantees. 

437 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

438 NodeKey(NodeType.TASK_INIT, "a"), 

439 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

440 NodeKey(NodeType.DATASET_TYPE, "schema"), 

441 NodeKey(NodeType.TASK_INIT, "b"), 

442 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

443 NodeKey(NodeType.TASK, "a"), 

444 NodeKey(NodeType.DATASET_TYPE, "a_log"), 

445 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

446 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

447 NodeKey(NodeType.TASK, "b"), 

448 NodeKey(NodeType.DATASET_TYPE, "b_log"), 

449 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

450 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

451 ], 

452 ) 

453 # Most users should only care that the tasks and dataset types are 

454 # topologically sorted. 

455 self.assertEqual(list(graph.tasks), ["a", "b"]) 

456 self.assertEqual( 

457 list(graph.dataset_types), 

458 [ 

459 "input_1", 

460 "a_config", 

461 "schema", 

462 "b_config", 

463 "a_log", 

464 "a_metadata", 

465 "intermediate_1", 

466 "b_log", 

467 "b_metadata", 

468 "output_1", 

469 ], 

470 ) 

471 # __str__ and __repr__ of course work on unsorted mapping views, too, 

472 # but the order of elements is then nondeterministic and hard to test. 

473 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})") 

474 self.assertEqual( 

475 repr(self.graph.dataset_types), 

476 ( 

477 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, " 

478 "intermediate_1, b_log, b_metadata, output_1})" 

479 ), 

480 ) 

481 

    def check_make_xgraph(
        self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
    ) -> None:
        """Check that the given graph exports as expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``) or round-tripped
        through serialization without tasks being imported (if
        ``imported_and_configured=False``).
        """
        xgraph = graph.make_xgraph()
        # The full export includes runtime edges, init edges, and one
        # synthetic init-to-task edge per task.
        expected_edges = (
            {edge.key for edge in graph.iter_edges()}
            | {edge.key for edge in graph.iter_edges(init=True)}
            | {
                (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
                (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
            }
        )
        test_edges = set(xgraph.edges)
        self.assertEqual(test_edges, expected_edges)
        expected_nodes = {
            NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                "schema", resolved, is_initial_query_constraint=False
            ),
            # Only overall inputs are initial query constraints.
            NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                "input_1", resolved, is_initial_query_constraint=True
            ),
            NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                "intermediate_1", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                "output_1", resolved, is_initial_query_constraint=False
            ),
        }
        test_nodes = dict(xgraph.nodes.items())
        # Compare key sets first, then node-by-node, for clearer failure
        # messages than a single dict comparison would give.
        self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
        for key, expected_node in expected_nodes.items():
            test_node = test_nodes[key]
            self.assertEqual(expected_node, test_node, key)

540 

    def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's init-only or runtime subset exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime subgraph: task nodes plus run-time dataset types only.
        run_xgraph = graph.make_bipartite_xgraph()
        self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init subgraph: task-init nodes plus init dataset types only.
        init_xgraph = graph.make_bipartite_xgraph(
            init=True,
        )
        self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )

586 

    def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's task-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: only the a -> b dependency remains once dataset
        # type nodes are projected out.
        run_xgraph = graph.make_task_xgraph()
        self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
            },
        )
        # Init projection: the same dependency via the 'schema' init dataset.
        init_xgraph = graph.make_task_xgraph(
            init=True,
        )
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
            },
        )

617 

    def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's dataset-type-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: each task's inputs gain edges to all of its
        # outputs (including automatic log/metadata outputs).
        run_xgraph = graph.make_dataset_type_xgraph()
        self.assertEqual(
            set(run_xgraph.edges),
            {
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "output_1"),
                ),
                (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
                ),
            },
        )
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init projection: 'schema' feeds task b's init, whose only
        # init-output is its config dataset.
        init_xgraph = graph.make_dataset_type_xgraph(init=True)
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )

676 

677 def get_expected_task_node( 

678 self, label: str, resolved: bool, imported_and_configured: bool = True 

679 ) -> dict[str, Any]: 

680 """Construct a networkx-export task node for comparison.""" 

681 result = self.get_expected_task_init_node( 

682 label, resolved, imported_and_configured=imported_and_configured 

683 ) 

684 if resolved: 

685 result["dimensions"] = self.dimensions.empty 

686 result["raw_dimensions"] = frozenset() 

687 return result 

688 

689 def get_expected_task_init_node( 

690 self, label: str, resolved: bool, imported_and_configured: bool = True 

691 ) -> dict[str, Any]: 

692 """Construct a networkx-export task init for comparison.""" 

693 result = { 

694 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask", 

695 "bipartite": 1, 

696 } 

697 if imported_and_configured: 

698 result["task_class"] = DynamicTestPipelineTask 

699 result["config"] = getattr(self, f"{label}_config") 

700 return result 

701 

702 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]: 

703 """Construct a networkx-export init-output config dataset type node for 

704 comparison. 

705 """ 

706 if not resolved: 

707 return {"bipartite": 0} 

708 else: 

709 return { 

710 "dataset_type": DatasetType( 

711 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), 

712 self.dimensions.empty, 

713 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

714 ), 

715 "is_initial_query_constraint": False, 

716 "is_prerequisite": False, 

717 "dimensions": self.dimensions.empty, 

718 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

719 "bipartite": 0, 

720 } 

721 

722 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]: 

723 """Construct a networkx-export output log dataset type node for 

724 comparison. 

725 """ 

726 if not resolved: 

727 return {"bipartite": 0} 

728 else: 

729 return { 

730 "dataset_type": DatasetType( 

731 acc.LOG_OUTPUT_TEMPLATE.format(label=label), 

732 self.dimensions.empty, 

733 acc.LOG_OUTPUT_STORAGE_CLASS, 

734 ), 

735 "is_initial_query_constraint": False, 

736 "is_prerequisite": False, 

737 "dimensions": self.dimensions.empty, 

738 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS, 

739 "bipartite": 0, 

740 } 

741 

742 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]: 

743 """Construct a networkx-export output metadata dataset type node for 

744 comparison. 

745 """ 

746 if not resolved: 

747 return {"bipartite": 0} 

748 else: 

749 return { 

750 "dataset_type": DatasetType( 

751 acc.METADATA_OUTPUT_TEMPLATE.format(label=label), 

752 self.dimensions.empty, 

753 acc.METADATA_OUTPUT_STORAGE_CLASS, 

754 ), 

755 "is_initial_query_constraint": False, 

756 "is_prerequisite": False, 

757 "dimensions": self.dimensions.empty, 

758 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS, 

759 "bipartite": 0, 

760 } 

761 

762 def get_expected_connection_node( 

763 self, name: str, resolved: bool, *, is_initial_query_constraint: bool 

764 ) -> dict[str, Any]: 

765 """Construct a networkx-export dataset type node for comparison.""" 

766 if not resolved: 

767 return {"bipartite": 0} 

768 else: 

769 return { 

770 "dataset_type": DatasetType( 

771 name, 

772 self.dimensions.empty, 

773 get_mock_name("StructuredDataDict"), 

774 ), 

775 "is_initial_query_constraint": is_initial_query_constraint, 

776 "is_prerequisite": False, 

777 "dimensions": self.dimensions.empty, 

778 "storage_class_name": get_mock_name("StructuredDataDict"), 

779 "bipartite": 0, 

780 } 

781 

    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with DimensionUniverse, all tasks added to
        it should have resolved dimensions, but not (yet) resolved dataset
        types. We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        # Because the universe is known up front, task dimensions resolve as
        # soon as the task is added.
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions (the task node is copied, not shared).
        self.graph.add_task_nodes([graph.tasks["b1"]])
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)

808 

    def test_group_by_dimensions(self) -> None:
        """Test PipelineGraph.group_by_dimensions."""
        # Grouping requires a resolved graph.
        with self.assertRaises(UnresolvedGraphError):
            self.graph.group_by_dimensions()
        # Give the two tasks (and their connections) distinct dimensions so
        # the grouping is nontrivial.
        self.a_config.dimensions = ["visit"]
        self.a_config.outputs["output1"].dimensions = ["visit"]
        self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
            dataset_type_name="prereq_1",
            multiple=True,
            dimensions=["htm7"],
            is_calibration=True,
        )
        self.b_config.dimensions = ["htm7"]
        self.b_config.inputs["input1"].dimensions = ["visit"]
        self.b_config.inputs["input1"].multiple = True
        self.b_config.outputs["output1"].dimensions = ["htm7"]
        self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        visit_dims = self.dimensions.conform(["visit"])
        htm7_dims = self.dimensions.conform(["htm7"])
        # Each dimension group maps to a (tasks, dataset types) pair.
        expected = {
            self.dimensions.empty.as_group(): (
                {},
                {
                    "schema": self.graph.dataset_types["schema"],
                    "input_1": self.graph.dataset_types["input_1"],
                    "a_config": self.graph.dataset_types["a_config"],
                    "b_config": self.graph.dataset_types["b_config"],
                },
            ),
            visit_dims: (
                {"a": self.graph.tasks["a"]},
                {
                    "a_log": self.graph.dataset_types["a_log"],
                    "a_metadata": self.graph.dataset_types["a_metadata"],
                    "intermediate_1": self.graph.dataset_types["intermediate_1"],
                },
            ),
            htm7_dims: (
                {"b": self.graph.tasks["b"]},
                {
                    "b_log": self.graph.dataset_types["b_log"],
                    "b_metadata": self.graph.dataset_types["b_metadata"],
                    "output_1": self.graph.dataset_types["output_1"],
                },
            ),
        }
        self.assertEqual(self.graph.group_by_dimensions(), expected)
        # Prerequisite dataset types are only included when requested.
        expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
        self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)

859 

    def test_add_and_remove(self) -> None:
        """Tests for adding and removing tasks and task subsets from a
        PipelineGraph.
        """
        # Keep a copy so we can diff against the original state later.
        original = self.graph.copy()
        # Can't remove a task while it's still in a subset.
        with self.assertRaises(PipelineGraphError):
            self.graph.remove_tasks(["b"], drop_from_subsets=False)
        self.assertEqual(original.diff_tasks(self.graph), [])
        # ...unless you remove the subset.
        self.graph.remove_task_subset("only_b")
        self.assertFalse(self.graph.task_subsets)
        ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False)
        self.assertFalse(referencing_subsets)
        self.assertEqual(self.graph.tasks.keys(), {"a"})
        self.assertEqual(
            original.diff_tasks(self.graph),
            ["Pipelines have different tasks: A & ~B = ['b'], B & ~A = []."],
        )
        # Add that task back in.
        self.graph.add_task_nodes([b])
        self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
        # Add the subset back in.
        self.graph.add_task_subset("only_b", {"b"})
        self.assertEqual(self.graph.task_subsets.keys(), {"only_b"})
        # Resolve the graph's dataset types and task dimensions.
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertTrue(self.graph.dataset_types.is_resolved("output_1"))
        self.assertTrue(self.graph.dataset_types.is_resolved("schema"))
        self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1"))
        # Remove the task while removing it from the subset automatically. This
        # should also unresolve (only) the referenced dataset types and drop
        # any datasets no longer attached to any task.
        self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
        ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
        self.assertEqual(referencing_subsets, {"only_b"})
        self.assertEqual(self.graph.tasks.keys(), {"a"})
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertNotIn("output1", self.graph.dataset_types)
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))

902 

    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        # Keep a copy so we can diff against the original state later.
        original = self.graph.copy()
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Changing the output storage class is an edge change.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        self.assertEqual(original.diff_tasks(self.graph), [])
        # Make a change that does affect edges; this will unresolve most
        # dataset types.
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        self.assertEqual(
            original.diff_tasks(self.graph),
            [
                "Output b.output1 has storage class '_mock_StructuredDataDict' in A, "
                "but '_mock_TaskMetadata' in B."
            ],
        )
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )

938 

939 def check_visualization(self, graph: PipelineGraph, expected: str, **kwargs: Any) -> None: 

940 """Run pipeline graph visualization with the given kwargs and check 

941 that the output is the given expected string. 

942 

943 Parameters 

944 ---------- 

945 graph : `lsst.pipe.base.pipeline_graph.PipelineGraph` 

946 Pipeline graph to visualize. 

947 expected : `str` 

948 Expected output of the visualization. Will be passed through 

949 `textwrap.dedent`, to allow it to be written with triple-quotes. 

950 **kwargs 

951 Forwarded to `lsst.pipe.base.pipeline_graph.visualization.show`. 

952 """ 

953 stream = io.StringIO() 

954 visualization.show(graph, stream, **kwargs) 

955 self.assertEqual(textwrap.dedent(expected), stream.getvalue()) 

956 

    def test_unresolved_visualization(self) -> None:
        """Test pipeline graph text-based visualization on unresolved
        graphs.
        """
        # Tasks only, with all tree-merging options disabled.
        self.check_visualization(
            self.graph,
            """
            ■ a

            ■ b
            """,
            merge_input_trees=0,
            merge_output_trees=0,
            merge_intermediates=False,
        )
        # Tasks and dataset types interleaved along the data-flow chain.
        self.check_visualization(
            self.graph,
            """
            ○ input_1

            ■ a

            ○ intermediate_1

            ■ b

            ○ output_1
            """,
            dataset_types=True,
        )

987 

    def test_resolved_visualization(self) -> None:
        """Test pipeline graph text-based visualization on resolved graphs."""
        self.graph.resolve(MockRegistry(dimensions=self.dimensions, dataset_types={}))
        # Concise task-class names, tasks only.
        self.check_visualization(
            self.graph,
            """
            ■ a: {} DynamicTestPipelineTask

            ■ b: {} DynamicTestPipelineTask
            """,
            task_classes="concise",
            merge_input_trees=0,
            merge_output_trees=0,
            merge_intermediates=False,
        )
        # Fully-qualified task-class names plus dataset type nodes with
        # their storage classes.
        self.check_visualization(
            self.graph,
            """
            ○ input_1: {} _mock_StructuredDataDict

            ■ a: {} lsst.pipe.base.tests.mocks.DynamicTestPipelineTask

            ○ intermediate_1: {} _mock_StructuredDataDict

            ■ b: {} lsst.pipe.base.tests.mocks.DynamicTestPipelineTask

            ○ output_1: {} _mock_StructuredDataDict
            """,
            task_classes="full",
            dataset_types=True,
        )

1019 

1020 

def _have_example_storage_classes() -> bool:
    """Check whether some storage classes work as expected.

    Given that these have registered converters, it shouldn't actually be
    necessary to be able to import those types in order to determine that
    they're convertible, but the storage class machinery is implemented such
    that types that can't be imported can't be converted, and while that's
    inconvenient here it's totally fine in non-testing scenarios where you only
    care about a storage class if you can actually use it.

    Returns
    -------
    `bool`
        `True` if the Arrow/Astropy/Pandas storage classes used by the
        tests can convert to and from ``ArrowTable``.
    """
    getter = StorageClassFactory().getStorageClass
    # Conversion must work in both directions for each pair the tests rely on.
    pairs = [("ArrowTable", "ArrowAstropy"), ("ArrowTable", "DataFrame")]
    return all(
        getter(a).can_convert(getter(b)) and getter(b).can_convert(getter(a)) for a, b in pairs
    )

1038 

1039 

class PipelineGraphResolveTestCase(unittest.TestCase):
    """More extensive tests for PipelineGraph.resolve and its private helper
    methods.

    These are in a separate TestCase because they utilize a different `setUp`
    from the rest of the `PipelineGraph` tests.
    """

    def setUp(self) -> None:
        # Fresh, empty configs for the two tasks each test customizes.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.b_config = DynamicTestPipelineTaskConfig()
        self.dimensions = DimensionUniverse()
        self.maxDiff = None

    def make_graph(self) -> PipelineGraph:
        # Build a two-task graph from the (possibly customized) configs.
        graph = PipelineGraph()
        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
        return graph

    def test_prerequisite_inconsistency(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite and another does not.

        This test will hopefully someday go away (along with
        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
        algorithm becomes more flexible.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_inconsistency_reversed(self) -> None:
        """Same as `test_prerequisite_inconsistency`, with the order the edges
        are added to the graph reversed.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_output(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite but another defines it as an output.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_missing(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the dataset type is not registered.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix"}
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_inconsistent(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the rest of the dimensions are
        inconsistent with the registered dataset type.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix", "visit"}
        )
        graph = self.make_graph()
        # Registered dimensions are missing "visit".
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.conform(["htm7"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )
        # Registered dimensions have an extra "skymap".
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.conform(["htm7", "visit", "skymap"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )

    def test_duplicate_outputs(self) -> None:
        """Test that we raise an exception when a dataset type node would have
        two write edges.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(DuplicateOutputError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_component_of_unregistered_parent(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is not registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_undefined_component(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but its storage class does not have that
        component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_bad_component_storage_class(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but the component's declared storage class is
        incompatible with the parent's actual component storage class.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="StructuredDataDict"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
                )
            )

    def test_input_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_output_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an output connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_input_storage_class_incompatible_with_output(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the storage class of the output connection.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_ambiguous_storage_class(self) -> None:
        """Test that we raise an exception when two input connections define
        the same dataset with different storage classes (even compatible ones)
        and there is no output connection or registry definition to take
        precedence.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where input edges have
        different but compatible storage classes and the dataset type is
        already registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        # The registry definition takes precedence for the graph-level node.
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        # Each edge adapts the graph-level dataset type to its own storage
        # class.
        self.assertEqual(
            a_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        self.assertEqual(
            b_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        # generalize_ref is the inverse of adapt_dataset_ref.
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_output_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where an output edge
        has a different but compatible storage class from the dataset type
        already registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        self.assertEqual(
            a_o.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_output(self) -> None:
        """Test successful resolution of a dataset type where an input edge has
        a different but compatible storage class from the output edge, and
        the dataset type is not registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        graph.resolve(MockRegistry(self.dimensions, {}))
        # With no registry definition, the output edge's storage class wins.
        self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
        self.assertEqual(
            a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type,
        )
        self.assertEqual(
            b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_input(self) -> None:
        """Test successful resolution of a component dataset type due to
        another input referencing the parent dataset type.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_output(self) -> None:
        """Test successful resolution of a component dataset type due to
        an output connection referencing the parent dataset type.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_registry(self) -> None:
        """Test successful resolution of a component dataset type due to
        the parent dataset type already being registered.
        """
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

1444 

1445 

if __name__ == "__main__":
    # Initialize LSST test utilities before handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()