Coverage for tests/test_pipeline_graph.py: 15%

511 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-31 09:39 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

21 

22"""Tests of things related to the GraphBuilder class.""" 

23 

24import copy 

25import io 

26import logging 

27import unittest 

28from typing import Any 

29 

30import lsst.pipe.base.automatic_connection_constants as acc 

31import lsst.utils.tests 

32from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory 

33from lsst.daf.butler.registry import MissingDatasetTypeError 

34from lsst.pipe.base.pipeline_graph import ( 

35 ConnectionTypeConsistencyError, 

36 DuplicateOutputError, 

37 Edge, 

38 EdgesChangedError, 

39 IncompatibleDatasetTypeError, 

40 NodeKey, 

41 NodeType, 

42 PipelineGraph, 

43 PipelineGraphError, 

44 TaskImportMode, 

45 UnresolvedGraphError, 

46) 

47from lsst.pipe.base.tests.mocks import ( 

48 DynamicConnectionConfig, 

49 DynamicTestPipelineTask, 

50 DynamicTestPipelineTaskConfig, 

51 get_mock_name, 

52) 

53 

54_LOG = logging.getLogger(__name__) 

55 

56 

class MockRegistry:
    """A test-utility stand-in for lsst.daf.butler.Registry that just knows
    how to get dataset types.

    Parameters
    ----------
    dimensions : `DimensionUniverse`
        Universe exposed via the ``dimensions`` attribute.
    dataset_types : `dict` [`str`, `DatasetType`]
        Mapping used to back `getDatasetType` lookups.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        """Return the dataset type with the given name.

        Raises
        ------
        MissingDatasetTypeError
            Raised if ``name`` is not a known dataset type.
        """
        # EAFP: translate the dict's KeyError into the butler exception
        # callers of a real Registry expect; `from None` hides the KeyError.
        try:
            return self._dataset_types[name]
        except KeyError:
            raise MissingDatasetTypeError(name) from None

72 

73class PipelineGraphTestCase(unittest.TestCase): 

74 """Tests for the `PipelineGraph` class. 

75 

76 Tests for `PipelineGraph.resolve` are mostly in 

77 `PipelineGraphResolveTestCase` later in this file. 

78 """ 

79 

80 def setUp(self) -> None: 

81 # Simple test pipeline has two tasks, 'a' and 'b', with dataset types 

82 # 'input', 'intermediate', and 'output'. There are no dimensions on 

83 # any of those. We add tasks in reverse order to better test sorting. 

84 # There is one labeled task subset, 'only_b', with just 'b' in it. 

85 # We copy the configs so the originals (the instance attributes) can 

86 # be modified and reused after the ones passed in to the graph are 

87 # frozen. 

88 self.description = "A pipeline for PipelineGraph unit tests." 

89 self.graph = PipelineGraph() 

90 self.graph.description = self.description 

91 self.b_config = DynamicTestPipelineTaskConfig() 

92 self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema") 

93 self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1") 

94 self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1") 

95 self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config)) 

96 self.a_config = DynamicTestPipelineTaskConfig() 

97 self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema") 

98 self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1") 

99 self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1") 

100 self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config)) 

101 self.graph.add_task_subset("only_b", ["b"]) 

102 self.subset_description = "A subset with only task B in it." 

103 self.graph.task_subsets["only_b"].description = self.subset_description 

104 self.dimensions = DimensionUniverse() 

105 self.maxDiff = None 

106 

107 def test_unresolved_accessors(self) -> None: 

108 """Test attribute accessors, iteration, and simple methods on a graph 

109 that has not had `PipelineGraph.resolve` called on it. 

110 """ 

111 self.check_base_accessors(self.graph) 

112 self.assertEqual( 

113 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)" 

114 ) 

115 

116 def test_sorting(self) -> None: 

117 """Test sort methods on PipelineGraph.""" 

118 self.assertFalse(self.graph.has_been_sorted) 

119 self.assertFalse(self.graph.is_sorted) 

120 self.graph.sort() 

121 self.check_sorted(self.graph) 

122 

123 def test_unresolved_xgraph_export(self) -> None: 

124 """Test exporting an unresolved PipelineGraph to networkx in various 

125 ways. 

126 """ 

127 self.check_make_xgraph(self.graph, resolved=False) 

128 self.check_make_bipartite_xgraph(self.graph, resolved=False) 

129 self.check_make_task_xgraph(self.graph, resolved=False) 

130 self.check_make_dataset_type_xgraph(self.graph, resolved=False) 

131 

132 def test_unresolved_stream_io(self) -> None: 

133 """Test round-tripping an unresolved PipelineGraph through in-memory 

134 serialization. 

135 """ 

136 stream = io.BytesIO() 

137 self.graph._write_stream(stream) 

138 stream.seek(0) 

139 roundtripped = PipelineGraph._read_stream(stream) 

140 self.check_make_xgraph(roundtripped, resolved=False) 

141 

142 def test_unresolved_file_io(self) -> None: 

143 """Test round-tripping an unresolved PipelineGraph through file 

144 serialization. 

145 """ 

146 with lsst.utils.tests.getTempFilePath(".json.gz") as filename: 

147 self.graph._write_uri(filename) 

148 roundtripped = PipelineGraph._read_uri(filename) 

149 self.check_make_xgraph(roundtripped, resolved=False) 

150 

151 def test_unresolved_deferred_import_io(self) -> None: 

152 """Test round-tripping an unresolved PipelineGraph through 

153 serialization, without immediately importing tasks on read. 

154 """ 

155 stream = io.BytesIO() 

156 self.graph._write_stream(stream) 

157 stream.seek(0) 

158 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT) 

159 self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False) 

160 # Check that we can still resolve the graph without importing tasks. 

161 roundtripped.resolve(MockRegistry(self.dimensions, {})) 

162 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) 

163 roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES) 

164 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) 

165 

166 def test_resolved_accessors(self) -> None: 

167 """Test attribute accessors, iteration, and simple methods on a graph 

168 that has had `PipelineGraph.resolve` called on it. 

169 

170 This includes the accessors available on unresolved graphs as well as 

171 new ones, and we expect the resolved graph to be sorted as well. 

172 """ 

173 self.graph.resolve(MockRegistry(self.dimensions, {})) 

174 self.check_base_accessors(self.graph) 

175 self.check_sorted(self.graph) 

176 self.assertEqual( 

177 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})" 

178 ) 

179 self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty) 

180 self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})") 

181 self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1")) 

182 self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty) 

183 self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict") 

184 self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict") 

185 

186 def test_resolved_xgraph_export(self) -> None: 

187 """Test exporting a resolved PipelineGraph to networkx in various 

188 ways. 

189 """ 

190 self.graph.resolve(MockRegistry(self.dimensions, {})) 

191 self.check_make_xgraph(self.graph, resolved=True) 

192 self.check_make_bipartite_xgraph(self.graph, resolved=True) 

193 self.check_make_task_xgraph(self.graph, resolved=True) 

194 self.check_make_dataset_type_xgraph(self.graph, resolved=True) 

195 

196 def test_resolved_stream_io(self) -> None: 

197 """Test round-tripping a resolved PipelineGraph through in-memory 

198 serialization. 

199 """ 

200 self.graph.resolve(MockRegistry(self.dimensions, {})) 

201 stream = io.BytesIO() 

202 self.graph._write_stream(stream) 

203 stream.seek(0) 

204 roundtripped = PipelineGraph._read_stream(stream) 

205 self.check_make_xgraph(roundtripped, resolved=True) 

206 

207 def test_resolved_file_io(self) -> None: 

208 """Test round-tripping a resolved PipelineGraph through file 

209 serialization. 

210 """ 

211 self.graph.resolve(MockRegistry(self.dimensions, {})) 

212 with lsst.utils.tests.getTempFilePath(".json.gz") as filename: 

213 self.graph._write_uri(filename) 

214 roundtripped = PipelineGraph._read_uri(filename) 

215 self.check_make_xgraph(roundtripped, resolved=True) 

216 

217 def test_resolved_deferred_import_io(self) -> None: 

218 """Test round-tripping a resolved PipelineGraph through serialization, 

219 without immediately importing tasks on read. 

220 """ 

221 self.graph.resolve(MockRegistry(self.dimensions, {})) 

222 stream = io.BytesIO() 

223 self.graph._write_stream(stream) 

224 stream.seek(0) 

225 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT) 

226 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) 

227 roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES) 

228 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) 

229 

230 def test_unresolved_copies(self) -> None: 

231 """Test making copies of an unresolved PipelineGraph.""" 

232 copy1 = self.graph.copy() 

233 self.assertIsNot(copy1, self.graph) 

234 self.check_make_xgraph(copy1, resolved=False) 

235 copy2 = copy.copy(self.graph) 

236 self.assertIsNot(copy2, self.graph) 

237 self.check_make_xgraph(copy2, resolved=False) 

238 copy3 = copy.deepcopy(self.graph) 

239 self.assertIsNot(copy3, self.graph) 

240 self.check_make_xgraph(copy3, resolved=False) 

241 

242 def test_resolved_copies(self) -> None: 

243 """Test making copies of a resolved PipelineGraph.""" 

244 self.graph.resolve(MockRegistry(self.dimensions, {})) 

245 copy1 = self.graph.copy() 

246 self.assertIsNot(copy1, self.graph) 

247 self.check_make_xgraph(copy1, resolved=True) 

248 copy2 = copy.copy(self.graph) 

249 self.assertIsNot(copy2, self.graph) 

250 self.check_make_xgraph(copy2, resolved=True) 

251 copy3 = copy.deepcopy(self.graph) 

252 self.assertIsNot(copy3, self.graph) 

253 self.check_make_xgraph(copy3, resolved=True) 

254 

255 def check_base_accessors(self, graph: PipelineGraph) -> None: 

256 """Run parameterized tests that check attribute access, iteration, and 

257 simple methods. 

258 

259 The given graph must be unchanged from the one defined in `setUp`, 

260 other than sorting. 

261 """ 

262 self.assertEqual(graph.description, self.description) 

263 self.assertEqual(graph.tasks.keys(), {"a", "b"}) 

264 self.assertEqual( 

265 graph.dataset_types.keys(), 

266 { 

267 "schema", 

268 "input_1", 

269 "intermediate_1", 

270 "output_1", 

271 "a_config", 

272 "a_log", 

273 "a_metadata", 

274 "b_config", 

275 "b_log", 

276 "b_metadata", 

277 }, 

278 ) 

279 self.assertEqual(graph.task_subsets.keys(), {"only_b"}) 

280 self.assertEqual( 

281 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)}, 

282 { 

283 ( 

284 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

285 NodeKey(NodeType.TASK, "a"), 

286 "input_1 -> a (input1)", 

287 ), 

288 ( 

289 NodeKey(NodeType.TASK, "a"), 

290 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

291 "a -> intermediate_1 (output1)", 

292 ), 

293 ( 

294 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

295 NodeKey(NodeType.TASK, "b"), 

296 "intermediate_1 -> b (input1)", 

297 ), 

298 ( 

299 NodeKey(NodeType.TASK, "b"), 

300 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

301 "b -> output_1 (output1)", 

302 ), 

303 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"), 

304 ( 

305 NodeKey(NodeType.TASK, "a"), 

306 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

307 "a -> a_metadata (_metadata)", 

308 ), 

309 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"), 

310 ( 

311 NodeKey(NodeType.TASK, "b"), 

312 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

313 "b -> b_metadata (_metadata)", 

314 ), 

315 }, 

316 ) 

317 self.assertEqual( 

318 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)}, 

319 { 

320 ( 

321 NodeKey(NodeType.TASK_INIT, "a"), 

322 NodeKey(NodeType.DATASET_TYPE, "schema"), 

323 "a -> schema (out_schema)", 

324 ), 

325 ( 

326 NodeKey(NodeType.DATASET_TYPE, "schema"), 

327 NodeKey(NodeType.TASK_INIT, "b"), 

328 "schema -> b (in_schema)", 

329 ), 

330 ( 

331 NodeKey(NodeType.TASK_INIT, "a"), 

332 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

333 "a -> a_config (_config)", 

334 ), 

335 ( 

336 NodeKey(NodeType.TASK_INIT, "b"), 

337 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

338 "b -> b_config (_config)", 

339 ), 

340 }, 

341 ) 

342 self.assertEqual( 

343 {(node_type, name) for node_type, name, _ in graph.iter_nodes()}, 

344 { 

345 NodeKey(NodeType.TASK, "a"), 

346 NodeKey(NodeType.TASK, "b"), 

347 NodeKey(NodeType.TASK_INIT, "a"), 

348 NodeKey(NodeType.TASK_INIT, "b"), 

349 NodeKey(NodeType.DATASET_TYPE, "schema"), 

350 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

351 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

352 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

353 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

354 NodeKey(NodeType.DATASET_TYPE, "a_log"), 

355 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

356 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

357 NodeKey(NodeType.DATASET_TYPE, "b_log"), 

358 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

359 }, 

360 ) 

361 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"}) 

362 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"}) 

363 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"}) 

364 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set()) 

365 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"}) 

366 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"}) 

367 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set()) 

368 

369 self.assertIsNone(graph.producing_edge_of("input_1")) 

370 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a") 

371 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b") 

372 self.assertIsNone(graph.producer_of("input_1")) 

373 self.assertEqual(graph.producer_of("intermediate_1").label, "a") 

374 self.assertEqual(graph.producer_of("output_1").label, "b") 

375 

376 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"}) 

377 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"}) 

378 self.assertEqual(graph.inputs_of("a", init=True).keys(), set()) 

379 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"}) 

380 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"}) 

381 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"}) 

382 self.assertEqual( 

383 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"} 

384 ) 

385 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"}) 

386 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"}) 

387 self.assertEqual( 

388 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"} 

389 ) 

390 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"}) 

391 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set()) 

392 

393 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) 

394 self.assertEqual( 

395 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}" 

396 ) 

397 

398 def check_sorted(self, graph: PipelineGraph) -> None: 

399 """Run a battery of tests on a PipelineGraph that must be 

400 deterministically sorted. 

401 

402 The given graph must be unchanged from the one defined in `setUp`, 

403 other than sorting. 

404 """ 

405 self.assertTrue(graph.has_been_sorted) 

406 self.assertTrue(graph.is_sorted) 

407 self.assertEqual( 

408 [(node_type, name) for node_type, name, _ in graph.iter_nodes()], 

409 [ 

410 # We only advertise that the order is topological and 

411 # deterministic, so this test is slightly over-specified; there 

412 # are other orders that are consistent with our guarantees. 

413 NodeKey(NodeType.DATASET_TYPE, "input_1"), 

414 NodeKey(NodeType.TASK_INIT, "a"), 

415 NodeKey(NodeType.DATASET_TYPE, "a_config"), 

416 NodeKey(NodeType.DATASET_TYPE, "schema"), 

417 NodeKey(NodeType.TASK_INIT, "b"), 

418 NodeKey(NodeType.DATASET_TYPE, "b_config"), 

419 NodeKey(NodeType.TASK, "a"), 

420 NodeKey(NodeType.DATASET_TYPE, "a_log"), 

421 NodeKey(NodeType.DATASET_TYPE, "a_metadata"), 

422 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

423 NodeKey(NodeType.TASK, "b"), 

424 NodeKey(NodeType.DATASET_TYPE, "b_log"), 

425 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

426 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

427 ], 

428 ) 

429 # Most users should only care that the tasks and dataset types are 

430 # topologically sorted. 

431 self.assertEqual(list(graph.tasks), ["a", "b"]) 

432 self.assertEqual( 

433 list(graph.dataset_types), 

434 [ 

435 "input_1", 

436 "a_config", 

437 "schema", 

438 "b_config", 

439 "a_log", 

440 "a_metadata", 

441 "intermediate_1", 

442 "b_log", 

443 "b_metadata", 

444 "output_1", 

445 ], 

446 ) 

447 # __str__ and __repr__ of course work on unsorted mapping views, too, 

448 # but the order of elements is then nondeterministic and hard to test. 

449 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})") 

450 self.assertEqual( 

451 repr(self.graph.dataset_types), 

452 ( 

453 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, " 

454 "intermediate_1, b_log, b_metadata, output_1})" 

455 ), 

456 ) 

457 

458 def check_make_xgraph( 

459 self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True 

460 ) -> None: 

461 """Check that the given graph exports as expected to networkx. 

462 

463 The given graph must be unchanged from the one defined in `setUp`, 

464 other than being resolved (if ``resolved=True``) or round-tripped 

465 through serialization without tasks being imported (if 

466 ``imported_and_configured=False``). 

467 """ 

468 xgraph = graph.make_xgraph() 

469 expected_edges = ( 

470 {edge.key for edge in graph.iter_edges()} 

471 | {edge.key for edge in graph.iter_edges(init=True)} 

472 | { 

473 (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME), 

474 (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME), 

475 } 

476 ) 

477 test_edges = set(xgraph.edges) 

478 self.assertEqual(test_edges, expected_edges) 

479 expected_nodes = { 

480 NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node( 

481 "a", resolved, imported_and_configured=imported_and_configured 

482 ), 

483 NodeKey(NodeType.TASK, "a"): self.get_expected_task_node( 

484 "a", resolved, imported_and_configured=imported_and_configured 

485 ), 

486 NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node( 

487 "b", resolved, imported_and_configured=imported_and_configured 

488 ), 

489 NodeKey(NodeType.TASK, "b"): self.get_expected_task_node( 

490 "b", resolved, imported_and_configured=imported_and_configured 

491 ), 

492 NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), 

493 NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), 

494 NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), 

495 NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), 

496 NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), 

497 NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), 

498 NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( 

499 "schema", resolved, is_initial_query_constraint=False 

500 ), 

501 NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( 

502 "input_1", resolved, is_initial_query_constraint=True 

503 ), 

504 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( 

505 "intermediate_1", resolved, is_initial_query_constraint=False 

506 ), 

507 NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( 

508 "output_1", resolved, is_initial_query_constraint=False 

509 ), 

510 } 

511 test_nodes = dict(xgraph.nodes.items()) 

512 self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys())) 

513 for key, expected_node in expected_nodes.items(): 

514 test_node = test_nodes[key] 

515 self.assertEqual(expected_node, test_node, key) 

516 

517 def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: 

518 """Check that the given graph's init-only or runtime subset exports as 

519 expected to networkx. 

520 

521 The given graph must be unchanged from the one defined in `setUp`, 

522 other than being resolved (if ``resolved=True``). 

523 """ 

524 run_xgraph = graph.make_bipartite_xgraph() 

525 self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()}) 

526 self.assertEqual( 

527 dict(run_xgraph.nodes.items()), 

528 { 

529 NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), 

530 NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), 

531 NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), 

532 NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), 

533 NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), 

534 NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), 

535 NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( 

536 "input_1", resolved, is_initial_query_constraint=True 

537 ), 

538 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( 

539 "intermediate_1", resolved, is_initial_query_constraint=False 

540 ), 

541 NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( 

542 "output_1", resolved, is_initial_query_constraint=False 

543 ), 

544 }, 

545 ) 

546 init_xgraph = graph.make_bipartite_xgraph( 

547 init=True, 

548 ) 

549 self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)}) 

550 self.assertEqual( 

551 dict(init_xgraph.nodes.items()), 

552 { 

553 NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), 

554 NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), 

555 NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( 

556 "schema", resolved, is_initial_query_constraint=False 

557 ), 

558 NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), 

559 NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), 

560 }, 

561 ) 

562 

563 def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: 

564 """Check that the given graph's task-only projection exports as 

565 expected to networkx. 

566 

567 The given graph must be unchanged from the one defined in `setUp`, 

568 other than being resolved (if ``resolved=True``). 

569 """ 

570 run_xgraph = graph.make_task_xgraph() 

571 self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))}) 

572 self.assertEqual( 

573 dict(run_xgraph.nodes.items()), 

574 { 

575 NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), 

576 NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), 

577 }, 

578 ) 

579 init_xgraph = graph.make_task_xgraph( 

580 init=True, 

581 ) 

582 self.assertEqual( 

583 set(init_xgraph.edges), 

584 {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))}, 

585 ) 

586 self.assertEqual( 

587 dict(init_xgraph.nodes.items()), 

588 { 

589 NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), 

590 NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), 

591 }, 

592 ) 

593 

594 def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: 

595 """Check that the given graph's dataset-type-only projection exports as 

596 expected to networkx. 

597 

598 The given graph must be unchanged from the one defined in `setUp`, 

599 other than being resolved (if ``resolved=True``). 

600 """ 

601 run_xgraph = graph.make_dataset_type_xgraph() 

602 self.assertEqual( 

603 set(run_xgraph.edges), 

604 { 

605 (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")), 

606 (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")), 

607 (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")), 

608 ( 

609 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

610 NodeKey(NodeType.DATASET_TYPE, "output_1"), 

611 ), 

612 (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")), 

613 ( 

614 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), 

615 NodeKey(NodeType.DATASET_TYPE, "b_metadata"), 

616 ), 

617 }, 

618 ) 

619 self.assertEqual( 

620 dict(run_xgraph.nodes.items()), 

621 { 

622 NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), 

623 NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), 

624 NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), 

625 NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), 

626 NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( 

627 "input_1", resolved, is_initial_query_constraint=True 

628 ), 

629 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( 

630 "intermediate_1", resolved, is_initial_query_constraint=False 

631 ), 

632 NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( 

633 "output_1", resolved, is_initial_query_constraint=False 

634 ), 

635 }, 

636 ) 

637 init_xgraph = graph.make_dataset_type_xgraph(init=True) 

638 self.assertEqual( 

639 set(init_xgraph.edges), 

640 {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))}, 

641 ) 

642 self.assertEqual( 

643 dict(init_xgraph.nodes.items()), 

644 { 

645 NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( 

646 "schema", resolved, is_initial_query_constraint=False 

647 ), 

648 NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), 

649 NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), 

650 }, 

651 ) 

652 

653 def get_expected_task_node( 

654 self, label: str, resolved: bool, imported_and_configured: bool = True 

655 ) -> dict[str, Any]: 

656 """Construct a networkx-export task node for comparison.""" 

657 result = self.get_expected_task_init_node( 

658 label, resolved, imported_and_configured=imported_and_configured 

659 ) 

660 if resolved: 

661 result["dimensions"] = self.dimensions.empty 

662 result["raw_dimensions"] = frozenset() 

663 return result 

664 

665 def get_expected_task_init_node( 

666 self, label: str, resolved: bool, imported_and_configured: bool = True 

667 ) -> dict[str, Any]: 

668 """Construct a networkx-export task init for comparison.""" 

669 result = { 

670 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask", 

671 "bipartite": 1, 

672 } 

673 if imported_and_configured: 

674 result["task_class"] = DynamicTestPipelineTask 

675 result["config"] = getattr(self, f"{label}_config") 

676 return result 

677 

678 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]: 

679 """Construct a networkx-export init-output config dataset type node for 

680 comparison. 

681 """ 

682 if not resolved: 

683 return {"bipartite": 0} 

684 else: 

685 return { 

686 "dataset_type": DatasetType( 

687 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), 

688 self.dimensions.empty, 

689 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

690 ), 

691 "is_initial_query_constraint": False, 

692 "is_prerequisite": False, 

693 "dimensions": self.dimensions.empty, 

694 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

695 "bipartite": 0, 

696 } 

697 

698 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]: 

699 """Construct a networkx-export output log dataset type node for 

700 comparison. 

701 """ 

702 if not resolved: 

703 return {"bipartite": 0} 

704 else: 

705 return { 

706 "dataset_type": DatasetType( 

707 acc.LOG_OUTPUT_TEMPLATE.format(label=label), 

708 self.dimensions.empty, 

709 acc.LOG_OUTPUT_STORAGE_CLASS, 

710 ), 

711 "is_initial_query_constraint": False, 

712 "is_prerequisite": False, 

713 "dimensions": self.dimensions.empty, 

714 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS, 

715 "bipartite": 0, 

716 } 

717 

718 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]: 

719 """Construct a networkx-export output metadata dataset type node for 

720 comparison. 

721 """ 

722 if not resolved: 

723 return {"bipartite": 0} 

724 else: 

725 return { 

726 "dataset_type": DatasetType( 

727 acc.METADATA_OUTPUT_TEMPLATE.format(label=label), 

728 self.dimensions.empty, 

729 acc.METADATA_OUTPUT_STORAGE_CLASS, 

730 ), 

731 "is_initial_query_constraint": False, 

732 "is_prerequisite": False, 

733 "dimensions": self.dimensions.empty, 

734 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS, 

735 "bipartite": 0, 

736 } 

737 

738 def get_expected_connection_node( 

739 self, name: str, resolved: bool, *, is_initial_query_constraint: bool 

740 ) -> dict[str, Any]: 

741 """Construct a networkx-export dataset type node for comparison.""" 

742 if not resolved: 

743 return {"bipartite": 0} 

744 else: 

745 return { 

746 "dataset_type": DatasetType( 

747 name, 

748 self.dimensions.empty, 

749 get_mock_name("StructuredDataDict"), 

750 ), 

751 "is_initial_query_constraint": is_initial_query_constraint, 

752 "is_prerequisite": False, 

753 "dimensions": self.dimensions.empty, 

754 "storage_class_name": get_mock_name("StructuredDataDict"), 

755 "bipartite": 0, 

756 } 

757 

    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with DimensionUniverse, all tasks added to
        it should have resolved dimensions, but not (yet) resolved dataset
        types. We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        # The data ID carries a universe, so the graph has one immediately.
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        # With a universe available, task dimensions resolve on add.
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions.
        self.graph.add_task_nodes([graph.tasks["b1"]])
        # The transfer produces a distinct node object in the target graph.
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)

    def test_group_by_dimensions(self) -> None:
        """Test PipelineGraph.group_by_dimensions."""
        # Grouping requires resolved dataset types.
        with self.assertRaises(UnresolvedGraphError):
            self.graph.group_by_dimensions()
        # Give the tasks and their connections nontrivial dimensions so the
        # grouping below has multiple distinct groups.
        self.a_config.dimensions = ["visit"]
        self.a_config.outputs["output1"].dimensions = ["visit"]
        self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
            dataset_type_name="prereq_1",
            multiple=True,
            dimensions=["htm7"],
            is_calibration=True,
        )
        self.b_config.dimensions = ["htm7"]
        self.b_config.inputs["input1"].dimensions = ["visit"]
        self.b_config.inputs["input1"].multiple = True
        self.b_config.outputs["output1"].dimensions = ["htm7"]
        self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        visit_dims = self.dimensions.extract(["visit"])
        htm7_dims = self.dimensions.extract(["htm7"])
        # Expected mapping: dimensions -> (tasks, dataset types) with those
        # exact dimensions.
        expected = {
            self.dimensions.empty: (
                {},
                {
                    "schema": self.graph.dataset_types["schema"],
                    "input_1": self.graph.dataset_types["input_1"],
                    "a_config": self.graph.dataset_types["a_config"],
                    "b_config": self.graph.dataset_types["b_config"],
                },
            ),
            visit_dims: (
                {"a": self.graph.tasks["a"]},
                {
                    "a_log": self.graph.dataset_types["a_log"],
                    "a_metadata": self.graph.dataset_types["a_metadata"],
                    "intermediate_1": self.graph.dataset_types["intermediate_1"],
                },
            ),
            htm7_dims: (
                {"b": self.graph.tasks["b"]},
                {
                    "b_log": self.graph.dataset_types["b_log"],
                    "b_metadata": self.graph.dataset_types["b_metadata"],
                    "output_1": self.graph.dataset_types["output_1"],
                },
            ),
        }
        self.assertEqual(self.graph.group_by_dimensions(), expected)
        # Prerequisite inputs are only included when explicitly requested.
        expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
        self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)

835 

836 def test_add_and_remove(self) -> None: 

837 """Tests for adding and removing tasks and task subsets from a 

838 PipelineGraph. 

839 """ 

840 # Can't remove a task while it's still in a subset. 

841 with self.assertRaises(PipelineGraphError): 

842 self.graph.remove_tasks(["b"], drop_from_subsets=False) 

843 # ...unless you remove the subset. 

844 self.graph.remove_task_subset("only_b") 

845 self.assertFalse(self.graph.task_subsets) 

846 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False) 

847 self.assertFalse(referencing_subsets) 

848 self.assertEqual(self.graph.tasks.keys(), {"a"}) 

849 # Add that task back in. 

850 self.graph.add_task_nodes([b]) 

851 self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) 

852 # Add the subset back in. 

853 self.graph.add_task_subset("only_b", {"b"}) 

854 self.assertEqual(self.graph.task_subsets.keys(), {"only_b"}) 

855 # Resolve the graph's dataset types and task dimensions. 

856 self.graph.resolve(MockRegistry(self.dimensions, {})) 

857 self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) 

858 self.assertTrue(self.graph.dataset_types.is_resolved("output_1")) 

859 self.assertTrue(self.graph.dataset_types.is_resolved("schema")) 

860 self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1")) 

861 # Remove the task while removing it from the subset automatically. This 

862 # should also unresolve (only) the referenced dataset types and drop 

863 # any datasets no longer attached to any task. 

864 self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) 

865 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True) 

866 self.assertEqual(referencing_subsets, {"only_b"}) 

867 self.assertEqual(self.graph.tasks.keys(), {"a"}) 

868 self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) 

869 self.assertNotIn("output1", self.graph.dataset_types) 

870 self.assertFalse(self.graph.dataset_types.is_resolved("schema")) 

871 self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1")) 

872 

    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Change an output connection's storage class; this alters the task's
        # edges, which matters for the check/assume options below.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        # Make a change that does affect edges; this will unresolve most
        # dataset types.
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )

899 

900 

def _have_example_storage_classes() -> bool:
    """Check whether the example Arrow/Astropy/Pandas storage classes can be
    converted among each other as the tests expect.

    Given that these have registered converters, it shouldn't actually be
    necessary to be able to import those types in order to determine that
    they're convertible, but the storage class machinery is implemented such
    that types that can't be imported can't be converted, and while that's
    inconvenient here it's totally fine in non-testing scenarios where you only
    care about a storage class if you can actually use it.
    """
    factory = StorageClassFactory()
    # Each (source, target) pair must be convertible for the tests that use
    # these storage classes to be meaningful.
    required_conversions = (
        ("ArrowTable", "ArrowAstropy"),
        ("ArrowAstropy", "ArrowTable"),
        ("ArrowTable", "DataFrame"),
        ("DataFrame", "ArrowTable"),
    )
    return all(
        factory.getStorageClass(source).can_convert(factory.getStorageClass(target))
        for source, target in required_conversions
    )

918 

919 

class PipelineGraphResolveTestCase(unittest.TestCase):
    """More extensive tests for PipelineGraph.resolve and its private helper
    methods.

    These are in a separate TestCase because they utilize a different `setUp`
    from the rest of the `PipelineGraph` tests.
    """

    def setUp(self) -> None:
        # Two task configurations; each test customizes their connections
        # before building the two-task graph via `make_graph`.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.b_config = DynamicTestPipelineTaskConfig()
        self.dimensions = DimensionUniverse()
        # Show full diffs on assertion failures.
        self.maxDiff = None

    def make_graph(self) -> PipelineGraph:
        """Construct a two-task pipeline graph from the current state of
        ``self.a_config`` and ``self.b_config``.
        """
        graph = PipelineGraph()
        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
        return graph

    def test_prerequisite_inconsistency(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite and another does not.

        This test will hopefully someday go away (along with
        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
        algorithm becomes more flexible.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_inconsistency_reversed(self) -> None:
        """Same as `test_prerequisite_inconsistency`, with the order the edges
        are added to the graph reversed.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_output(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite but another defines it as an output.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_missing(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the dataset type is not registered.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix"}
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_inconsistent(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the rest of the dimensions are
        inconsistent with the registered dataset type.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix", "visit"}
        )
        graph = self.make_graph()
        # Registered dimensions are missing "visit" entirely.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )
        # Registered dimensions have extra dimensions beyond the connection's.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )

    def test_duplicate_outputs(self) -> None:
        """Test that we raise an exception when a dataset type node would have
        two write edges.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(DuplicateOutputError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_component_of_unregistered_parent(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is not registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_undefined_component(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but its storage class does not have that
        component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_bad_component_storage_class(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but does not have that component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="StructuredDataDict"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
                )
            )

    def test_input_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_output_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an output connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_input_storage_class_incompatible_with_output(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the storage class of the output connection.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_ambiguous_storage_class(self) -> None:
        """Test that we raise an exception when two input connections define
        the same dataset with different storage classes (even compatible ones)
        and there is no output connection or registry definition to take
        precedence.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where input edges have
        different but compatible storage classes and the dataset type is
        already registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        # The registry definition wins; each edge adapts it to its own
        # storage class.
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(
            a_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        self.assertEqual(
            b_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        # generalize_ref is the inverse of adapt_dataset_ref.
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_output_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where an output edge
        has a different but compatible storage class from the dataset type
        already registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        self.assertEqual(
            a_o.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_output(self) -> None:
        """Test successful resolution of a dataset type where an input edge has
        a different but compatible storage class from the output edge, and
        the dataset type is not registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        graph.resolve(MockRegistry(self.dimensions, {}))
        # With no registry definition, the output edge's storage class wins.
        self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
        self.assertEqual(
            a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type,
        )
        self.assertEqual(
            b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_input(self) -> None:
        """Test successful resolution of a component dataset type due to
        another input referencing the parent dataset type.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        # The graph node is always the parent dataset type; component access
        # happens via the edge.
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_output(self) -> None:
        """Test successful resolution of a component dataset type due to
        an output connection referencing the parent dataset type.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_registry(self) -> None:
        """Test successful resolution of a component dataset type due to
        the parent dataset type already being registered.
        """
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

1325 

if __name__ == "__main__":
    # Standard LSST test-module entry point: initialize test utilities, then
    # hand off to unittest's CLI runner.
    lsst.utils.tests.init()
    unittest.main()