Coverage for tests/test_pipeline_graph.py: 15%
511 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of things related to the PipelineGraph class."""
24import copy
25import io
26import logging
27import unittest
28from typing import Any
30import lsst.pipe.base.automatic_connection_constants as acc
31import lsst.utils.tests
32from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory
33from lsst.daf.butler.registry import MissingDatasetTypeError
34from lsst.pipe.base.pipeline_graph import (
35 ConnectionTypeConsistencyError,
36 DuplicateOutputError,
37 Edge,
38 EdgesChangedError,
39 IncompatibleDatasetTypeError,
40 NodeKey,
41 NodeType,
42 PipelineGraph,
43 PipelineGraphError,
44 TaskImportMode,
45 UnresolvedGraphError,
46)
47from lsst.pipe.base.tests.mocks import (
48 DynamicConnectionConfig,
49 DynamicTestPipelineTask,
50 DynamicTestPipelineTaskConfig,
51 get_mock_name,
52)
54_LOG = logging.getLogger(__name__)
class MockRegistry:
    """A test-utility stand-in for lsst.daf.butler.Registry that just knows
    how to get dataset types.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        # Mirror the real Registry's public ``dimensions`` attribute and keep
        # the dataset-type mapping private, as the tests only need lookups.
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        """Return the dataset type with the given name.

        Raises the butler's `MissingDatasetTypeError` (rather than a bare
        `KeyError`) so callers see registry-like behavior.
        """
        if name not in self._dataset_types:
            raise MissingDatasetTypeError(name)
        return self._dataset_types[name]
73class PipelineGraphTestCase(unittest.TestCase):
74 """Tests for the `PipelineGraph` class.
76 Tests for `PipelineGraph.resolve` are mostly in
77 `PipelineGraphResolveTestCase` later in this file.
78 """
    def setUp(self) -> None:
        """Build the small two-task pipeline graph shared by every test."""
        # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
        # 'input', 'intermediate', and 'output'. There are no dimensions on
        # any of those. We add tasks in reverse order to better test sorting.
        # There is one labeled task subset, 'only_b', with just 'b' in it.
        # We copy the configs so the originals (the instance attributes) can
        # be modified and reused after the ones passed in to the graph are
        # frozen.
        self.description = "A pipeline for PipelineGraph unit tests."
        self.graph = PipelineGraph()
        self.graph.description = self.description
        # Task 'b': consumes 'intermediate_1' and the init-input 'schema',
        # produces 'output_1'.  Added first (before 'a') on purpose.
        self.b_config = DynamicTestPipelineTaskConfig()
        self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
        self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
        # Task 'a': consumes the overall-input 'input_1', produces
        # 'intermediate_1' and the init-output 'schema' that 'b' consumes.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
        self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
        self.graph.add_task_subset("only_b", ["b"])
        self.subset_description = "A subset with only task B in it."
        self.graph.task_subsets["only_b"].description = self.subset_description
        self.dimensions = DimensionUniverse()
        # Show full diffs on assertion failures; the expected structures here
        # are large.
        self.maxDiff = None
107 def test_unresolved_accessors(self) -> None:
108 """Test attribute accessors, iteration, and simple methods on a graph
109 that has not had `PipelineGraph.resolve` called on it.
110 """
111 self.check_base_accessors(self.graph)
112 self.assertEqual(
113 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)"
114 )
116 def test_sorting(self) -> None:
117 """Test sort methods on PipelineGraph."""
118 self.assertFalse(self.graph.has_been_sorted)
119 self.assertFalse(self.graph.is_sorted)
120 self.graph.sort()
121 self.check_sorted(self.graph)
123 def test_unresolved_xgraph_export(self) -> None:
124 """Test exporting an unresolved PipelineGraph to networkx in various
125 ways.
126 """
127 self.check_make_xgraph(self.graph, resolved=False)
128 self.check_make_bipartite_xgraph(self.graph, resolved=False)
129 self.check_make_task_xgraph(self.graph, resolved=False)
130 self.check_make_dataset_type_xgraph(self.graph, resolved=False)
132 def test_unresolved_stream_io(self) -> None:
133 """Test round-tripping an unresolved PipelineGraph through in-memory
134 serialization.
135 """
136 stream = io.BytesIO()
137 self.graph._write_stream(stream)
138 stream.seek(0)
139 roundtripped = PipelineGraph._read_stream(stream)
140 self.check_make_xgraph(roundtripped, resolved=False)
142 def test_unresolved_file_io(self) -> None:
143 """Test round-tripping an unresolved PipelineGraph through file
144 serialization.
145 """
146 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
147 self.graph._write_uri(filename)
148 roundtripped = PipelineGraph._read_uri(filename)
149 self.check_make_xgraph(roundtripped, resolved=False)
151 def test_unresolved_deferred_import_io(self) -> None:
152 """Test round-tripping an unresolved PipelineGraph through
153 serialization, without immediately importing tasks on read.
154 """
155 stream = io.BytesIO()
156 self.graph._write_stream(stream)
157 stream.seek(0)
158 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
159 self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
160 # Check that we can still resolve the graph without importing tasks.
161 roundtripped.resolve(MockRegistry(self.dimensions, {}))
162 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
163 roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
164 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
166 def test_resolved_accessors(self) -> None:
167 """Test attribute accessors, iteration, and simple methods on a graph
168 that has had `PipelineGraph.resolve` called on it.
170 This includes the accessors available on unresolved graphs as well as
171 new ones, and we expect the resolved graph to be sorted as well.
172 """
173 self.graph.resolve(MockRegistry(self.dimensions, {}))
174 self.check_base_accessors(self.graph)
175 self.check_sorted(self.graph)
176 self.assertEqual(
177 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})"
178 )
179 self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty)
180 self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})")
181 self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1"))
182 self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
183 self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
184 self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")
186 def test_resolved_xgraph_export(self) -> None:
187 """Test exporting a resolved PipelineGraph to networkx in various
188 ways.
189 """
190 self.graph.resolve(MockRegistry(self.dimensions, {}))
191 self.check_make_xgraph(self.graph, resolved=True)
192 self.check_make_bipartite_xgraph(self.graph, resolved=True)
193 self.check_make_task_xgraph(self.graph, resolved=True)
194 self.check_make_dataset_type_xgraph(self.graph, resolved=True)
196 def test_resolved_stream_io(self) -> None:
197 """Test round-tripping a resolved PipelineGraph through in-memory
198 serialization.
199 """
200 self.graph.resolve(MockRegistry(self.dimensions, {}))
201 stream = io.BytesIO()
202 self.graph._write_stream(stream)
203 stream.seek(0)
204 roundtripped = PipelineGraph._read_stream(stream)
205 self.check_make_xgraph(roundtripped, resolved=True)
207 def test_resolved_file_io(self) -> None:
208 """Test round-tripping a resolved PipelineGraph through file
209 serialization.
210 """
211 self.graph.resolve(MockRegistry(self.dimensions, {}))
212 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
213 self.graph._write_uri(filename)
214 roundtripped = PipelineGraph._read_uri(filename)
215 self.check_make_xgraph(roundtripped, resolved=True)
217 def test_resolved_deferred_import_io(self) -> None:
218 """Test round-tripping a resolved PipelineGraph through serialization,
219 without immediately importing tasks on read.
220 """
221 self.graph.resolve(MockRegistry(self.dimensions, {}))
222 stream = io.BytesIO()
223 self.graph._write_stream(stream)
224 stream.seek(0)
225 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
226 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
227 roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
228 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
230 def test_unresolved_copies(self) -> None:
231 """Test making copies of an unresolved PipelineGraph."""
232 copy1 = self.graph.copy()
233 self.assertIsNot(copy1, self.graph)
234 self.check_make_xgraph(copy1, resolved=False)
235 copy2 = copy.copy(self.graph)
236 self.assertIsNot(copy2, self.graph)
237 self.check_make_xgraph(copy2, resolved=False)
238 copy3 = copy.deepcopy(self.graph)
239 self.assertIsNot(copy3, self.graph)
240 self.check_make_xgraph(copy3, resolved=False)
242 def test_resolved_copies(self) -> None:
243 """Test making copies of a resolved PipelineGraph."""
244 self.graph.resolve(MockRegistry(self.dimensions, {}))
245 copy1 = self.graph.copy()
246 self.assertIsNot(copy1, self.graph)
247 self.check_make_xgraph(copy1, resolved=True)
248 copy2 = copy.copy(self.graph)
249 self.assertIsNot(copy2, self.graph)
250 self.check_make_xgraph(copy2, resolved=True)
251 copy3 = copy.deepcopy(self.graph)
252 self.assertIsNot(copy3, self.graph)
253 self.check_make_xgraph(copy3, resolved=True)
255 def check_base_accessors(self, graph: PipelineGraph) -> None:
256 """Run parameterized tests that check attribute access, iteration, and
257 simple methods.
259 The given graph must be unchanged from the one defined in `setUp`,
260 other than sorting.
261 """
262 self.assertEqual(graph.description, self.description)
263 self.assertEqual(graph.tasks.keys(), {"a", "b"})
264 self.assertEqual(
265 graph.dataset_types.keys(),
266 {
267 "schema",
268 "input_1",
269 "intermediate_1",
270 "output_1",
271 "a_config",
272 "a_log",
273 "a_metadata",
274 "b_config",
275 "b_log",
276 "b_metadata",
277 },
278 )
279 self.assertEqual(graph.task_subsets.keys(), {"only_b"})
280 self.assertEqual(
281 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)},
282 {
283 (
284 NodeKey(NodeType.DATASET_TYPE, "input_1"),
285 NodeKey(NodeType.TASK, "a"),
286 "input_1 -> a (input1)",
287 ),
288 (
289 NodeKey(NodeType.TASK, "a"),
290 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
291 "a -> intermediate_1 (output1)",
292 ),
293 (
294 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
295 NodeKey(NodeType.TASK, "b"),
296 "intermediate_1 -> b (input1)",
297 ),
298 (
299 NodeKey(NodeType.TASK, "b"),
300 NodeKey(NodeType.DATASET_TYPE, "output_1"),
301 "b -> output_1 (output1)",
302 ),
303 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"),
304 (
305 NodeKey(NodeType.TASK, "a"),
306 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
307 "a -> a_metadata (_metadata)",
308 ),
309 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"),
310 (
311 NodeKey(NodeType.TASK, "b"),
312 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
313 "b -> b_metadata (_metadata)",
314 ),
315 },
316 )
317 self.assertEqual(
318 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)},
319 {
320 (
321 NodeKey(NodeType.TASK_INIT, "a"),
322 NodeKey(NodeType.DATASET_TYPE, "schema"),
323 "a -> schema (out_schema)",
324 ),
325 (
326 NodeKey(NodeType.DATASET_TYPE, "schema"),
327 NodeKey(NodeType.TASK_INIT, "b"),
328 "schema -> b (in_schema)",
329 ),
330 (
331 NodeKey(NodeType.TASK_INIT, "a"),
332 NodeKey(NodeType.DATASET_TYPE, "a_config"),
333 "a -> a_config (_config)",
334 ),
335 (
336 NodeKey(NodeType.TASK_INIT, "b"),
337 NodeKey(NodeType.DATASET_TYPE, "b_config"),
338 "b -> b_config (_config)",
339 ),
340 },
341 )
342 self.assertEqual(
343 {(node_type, name) for node_type, name, _ in graph.iter_nodes()},
344 {
345 NodeKey(NodeType.TASK, "a"),
346 NodeKey(NodeType.TASK, "b"),
347 NodeKey(NodeType.TASK_INIT, "a"),
348 NodeKey(NodeType.TASK_INIT, "b"),
349 NodeKey(NodeType.DATASET_TYPE, "schema"),
350 NodeKey(NodeType.DATASET_TYPE, "input_1"),
351 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
352 NodeKey(NodeType.DATASET_TYPE, "output_1"),
353 NodeKey(NodeType.DATASET_TYPE, "a_config"),
354 NodeKey(NodeType.DATASET_TYPE, "a_log"),
355 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
356 NodeKey(NodeType.DATASET_TYPE, "b_config"),
357 NodeKey(NodeType.DATASET_TYPE, "b_log"),
358 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
359 },
360 )
361 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"})
362 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"})
363 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"})
364 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set())
365 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"})
366 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"})
367 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set())
369 self.assertIsNone(graph.producing_edge_of("input_1"))
370 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a")
371 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b")
372 self.assertIsNone(graph.producer_of("input_1"))
373 self.assertEqual(graph.producer_of("intermediate_1").label, "a")
374 self.assertEqual(graph.producer_of("output_1").label, "b")
376 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"})
377 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"})
378 self.assertEqual(graph.inputs_of("a", init=True).keys(), set())
379 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"})
380 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"})
381 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"})
382 self.assertEqual(
383 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"}
384 )
385 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"})
386 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"})
387 self.assertEqual(
388 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"}
389 )
390 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"})
391 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set())
393 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
394 self.assertEqual(
395 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}"
396 )
398 def check_sorted(self, graph: PipelineGraph) -> None:
399 """Run a battery of tests on a PipelineGraph that must be
400 deterministically sorted.
402 The given graph must be unchanged from the one defined in `setUp`,
403 other than sorting.
404 """
405 self.assertTrue(graph.has_been_sorted)
406 self.assertTrue(graph.is_sorted)
407 self.assertEqual(
408 [(node_type, name) for node_type, name, _ in graph.iter_nodes()],
409 [
410 # We only advertise that the order is topological and
411 # deterministic, so this test is slightly over-specified; there
412 # are other orders that are consistent with our guarantees.
413 NodeKey(NodeType.DATASET_TYPE, "input_1"),
414 NodeKey(NodeType.TASK_INIT, "a"),
415 NodeKey(NodeType.DATASET_TYPE, "a_config"),
416 NodeKey(NodeType.DATASET_TYPE, "schema"),
417 NodeKey(NodeType.TASK_INIT, "b"),
418 NodeKey(NodeType.DATASET_TYPE, "b_config"),
419 NodeKey(NodeType.TASK, "a"),
420 NodeKey(NodeType.DATASET_TYPE, "a_log"),
421 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
422 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
423 NodeKey(NodeType.TASK, "b"),
424 NodeKey(NodeType.DATASET_TYPE, "b_log"),
425 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
426 NodeKey(NodeType.DATASET_TYPE, "output_1"),
427 ],
428 )
429 # Most users should only care that the tasks and dataset types are
430 # topologically sorted.
431 self.assertEqual(list(graph.tasks), ["a", "b"])
432 self.assertEqual(
433 list(graph.dataset_types),
434 [
435 "input_1",
436 "a_config",
437 "schema",
438 "b_config",
439 "a_log",
440 "a_metadata",
441 "intermediate_1",
442 "b_log",
443 "b_metadata",
444 "output_1",
445 ],
446 )
447 # __str__ and __repr__ of course work on unsorted mapping views, too,
448 # but the order of elements is then nondeterministic and hard to test.
449 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})")
450 self.assertEqual(
451 repr(self.graph.dataset_types),
452 (
453 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, "
454 "intermediate_1, b_log, b_metadata, output_1})"
455 ),
456 )
    def check_make_xgraph(
        self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
    ) -> None:
        """Check that the given graph exports as expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``) or round-tripped
        through serialization without tasks being imported (if
        ``imported_and_configured=False``).
        """
        xgraph = graph.make_xgraph()
        # The full export contains runtime edges, init edges, and one special
        # init->task edge per task.
        expected_edges = (
            {edge.key for edge in graph.iter_edges()}
            | {edge.key for edge in graph.iter_edges(init=True)}
            | {
                (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
                (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
            }
        )
        test_edges = set(xgraph.edges)
        self.assertEqual(test_edges, expected_edges)
        # Expected node attributes for every node in the export; the
        # get_expected_* helpers encode how attributes vary with ``resolved``
        # and ``imported_and_configured``.
        expected_nodes = {
            NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                "schema", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                "input_1", resolved, is_initial_query_constraint=True
            ),
            NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                "intermediate_1", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                "output_1", resolved, is_initial_query_constraint=False
            ),
        }
        test_nodes = dict(xgraph.nodes.items())
        self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
        # Compare node-by-node so a failure message identifies the bad key.
        for key, expected_node in expected_nodes.items():
            test_node = test_nodes[key]
            self.assertEqual(expected_node, test_node, key)
    def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's init-only or runtime subset exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime subgraph: task nodes, runtime dataset types, and automatic
        # log/metadata outputs only.
        run_xgraph = graph.make_bipartite_xgraph()
        self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init subgraph: task-init nodes, the schema connection, and the
        # automatic config outputs.
        init_xgraph = graph.make_bipartite_xgraph(
            init=True,
        )
        self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
    def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's task-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: the dataset-type nodes are projected away,
        # leaving a single a -> b dependency edge.
        run_xgraph = graph.make_task_xgraph()
        self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
            },
        )
        # Init projection: same shape, but with the task-init nodes (linked
        # via the 'schema' init dataset).
        init_xgraph = graph.make_task_xgraph(
            init=True,
        )
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
            },
        )
    def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's dataset-type-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: task nodes are projected away, so each task's
        # inputs connect directly to all of its outputs (including the
        # automatic log/metadata datasets).
        run_xgraph = graph.make_dataset_type_xgraph()
        self.assertEqual(
            set(run_xgraph.edges),
            {
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "output_1"),
                ),
                (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
                ),
            },
        )
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init projection: 'schema' feeds task-init 'b', whose only init
        # output is 'b_config'; 'a_config' is isolated (no init inputs to 'a').
        init_xgraph = graph.make_dataset_type_xgraph(init=True)
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
653 def get_expected_task_node(
654 self, label: str, resolved: bool, imported_and_configured: bool = True
655 ) -> dict[str, Any]:
656 """Construct a networkx-export task node for comparison."""
657 result = self.get_expected_task_init_node(
658 label, resolved, imported_and_configured=imported_and_configured
659 )
660 if resolved:
661 result["dimensions"] = self.dimensions.empty
662 result["raw_dimensions"] = frozenset()
663 return result
665 def get_expected_task_init_node(
666 self, label: str, resolved: bool, imported_and_configured: bool = True
667 ) -> dict[str, Any]:
668 """Construct a networkx-export task init for comparison."""
669 result = {
670 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask",
671 "bipartite": 1,
672 }
673 if imported_and_configured:
674 result["task_class"] = DynamicTestPipelineTask
675 result["config"] = getattr(self, f"{label}_config")
676 return result
678 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]:
679 """Construct a networkx-export init-output config dataset type node for
680 comparison.
681 """
682 if not resolved:
683 return {"bipartite": 0}
684 else:
685 return {
686 "dataset_type": DatasetType(
687 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
688 self.dimensions.empty,
689 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
690 ),
691 "is_initial_query_constraint": False,
692 "is_prerequisite": False,
693 "dimensions": self.dimensions.empty,
694 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
695 "bipartite": 0,
696 }
698 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]:
699 """Construct a networkx-export output log dataset type node for
700 comparison.
701 """
702 if not resolved:
703 return {"bipartite": 0}
704 else:
705 return {
706 "dataset_type": DatasetType(
707 acc.LOG_OUTPUT_TEMPLATE.format(label=label),
708 self.dimensions.empty,
709 acc.LOG_OUTPUT_STORAGE_CLASS,
710 ),
711 "is_initial_query_constraint": False,
712 "is_prerequisite": False,
713 "dimensions": self.dimensions.empty,
714 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS,
715 "bipartite": 0,
716 }
718 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]:
719 """Construct a networkx-export output metadata dataset type node for
720 comparison.
721 """
722 if not resolved:
723 return {"bipartite": 0}
724 else:
725 return {
726 "dataset_type": DatasetType(
727 acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
728 self.dimensions.empty,
729 acc.METADATA_OUTPUT_STORAGE_CLASS,
730 ),
731 "is_initial_query_constraint": False,
732 "is_prerequisite": False,
733 "dimensions": self.dimensions.empty,
734 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS,
735 "bipartite": 0,
736 }
738 def get_expected_connection_node(
739 self, name: str, resolved: bool, *, is_initial_query_constraint: bool
740 ) -> dict[str, Any]:
741 """Construct a networkx-export dataset type node for comparison."""
742 if not resolved:
743 return {"bipartite": 0}
744 else:
745 return {
746 "dataset_type": DatasetType(
747 name,
748 self.dimensions.empty,
749 get_mock_name("StructuredDataDict"),
750 ),
751 "is_initial_query_constraint": is_initial_query_constraint,
752 "is_prerequisite": False,
753 "dimensions": self.dimensions.empty,
754 "storage_class_name": get_mock_name("StructuredDataDict"),
755 "bipartite": 0,
756 }
    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with DimensionUniverse, all tasks added to
        it should have resolved dimensions, but not (yet) resolved dataset
        types. We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        # Adding a task to a graph that already has a universe resolves the
        # task's dimensions immediately.
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions (``self.graph`` has no universe).
        self.graph.add_task_nodes([graph.tasks["b1"]])
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)
785 def test_group_by_dimensions(self) -> None:
786 """Test PipelineGraph.group_by_dimensions."""
787 with self.assertRaises(UnresolvedGraphError):
788 self.graph.group_by_dimensions()
789 self.a_config.dimensions = ["visit"]
790 self.a_config.outputs["output1"].dimensions = ["visit"]
791 self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
792 dataset_type_name="prereq_1",
793 multiple=True,
794 dimensions=["htm7"],
795 is_calibration=True,
796 )
797 self.b_config.dimensions = ["htm7"]
798 self.b_config.inputs["input1"].dimensions = ["visit"]
799 self.b_config.inputs["input1"].multiple = True
800 self.b_config.outputs["output1"].dimensions = ["htm7"]
801 self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
802 self.graph.resolve(MockRegistry(self.dimensions, {}))
803 visit_dims = self.dimensions.extract(["visit"])
804 htm7_dims = self.dimensions.extract(["htm7"])
805 expected = {
806 self.dimensions.empty: (
807 {},
808 {
809 "schema": self.graph.dataset_types["schema"],
810 "input_1": self.graph.dataset_types["input_1"],
811 "a_config": self.graph.dataset_types["a_config"],
812 "b_config": self.graph.dataset_types["b_config"],
813 },
814 ),
815 visit_dims: (
816 {"a": self.graph.tasks["a"]},
817 {
818 "a_log": self.graph.dataset_types["a_log"],
819 "a_metadata": self.graph.dataset_types["a_metadata"],
820 "intermediate_1": self.graph.dataset_types["intermediate_1"],
821 },
822 ),
823 htm7_dims: (
824 {"b": self.graph.tasks["b"]},
825 {
826 "b_log": self.graph.dataset_types["b_log"],
827 "b_metadata": self.graph.dataset_types["b_metadata"],
828 "output_1": self.graph.dataset_types["output_1"],
829 },
830 ),
831 }
832 self.assertEqual(self.graph.group_by_dimensions(), expected)
833 expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
834 self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)
836 def test_add_and_remove(self) -> None:
837 """Tests for adding and removing tasks and task subsets from a
838 PipelineGraph.
839 """
840 # Can't remove a task while it's still in a subset.
841 with self.assertRaises(PipelineGraphError):
842 self.graph.remove_tasks(["b"], drop_from_subsets=False)
843 # ...unless you remove the subset.
844 self.graph.remove_task_subset("only_b")
845 self.assertFalse(self.graph.task_subsets)
846 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False)
847 self.assertFalse(referencing_subsets)
848 self.assertEqual(self.graph.tasks.keys(), {"a"})
849 # Add that task back in.
850 self.graph.add_task_nodes([b])
851 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
852 # Add the subset back in.
853 self.graph.add_task_subset("only_b", {"b"})
854 self.assertEqual(self.graph.task_subsets.keys(), {"only_b"})
855 # Resolve the graph's dataset types and task dimensions.
856 self.graph.resolve(MockRegistry(self.dimensions, {}))
857 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
858 self.assertTrue(self.graph.dataset_types.is_resolved("output_1"))
859 self.assertTrue(self.graph.dataset_types.is_resolved("schema"))
860 self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1"))
861 # Remove the task while removing it from the subset automatically. This
862 # should also unresolve (only) the referenced dataset types and drop
863 # any datasets no longer attached to any task.
864 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
865 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
866 self.assertEqual(referencing_subsets, {"only_b"})
867 self.assertEqual(self.graph.tasks.keys(), {"a"})
868 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
869 self.assertNotIn("output1", self.graph.dataset_types)
870 self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
871 self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Change an output connection's storage class; this modifies an edge.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        # The new storage class really does change an edge, so asking to
        # verify that edges are unchanged must fail.
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        # Make a change that does affect edges; this will unresolve most
        # dataset types.
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )
def _have_example_storage_classes() -> bool:
    """Check whether some storage classes work as expected.

    Given that these have registered converters, it shouldn't actually be
    necessary to be able to import those types in order to determine that
    they're convertible, but the storage class machinery is implemented such
    that types that can't be imported can't be converted, and while that's
    inconvenient here it's totally fine in non-testing scenarios where you
    only care about a storage class if you can actually use it.
    """
    getter = StorageClassFactory().getStorageClass
    # Each pair (a, b) must satisfy "a can convert from b".
    pairs = (
        ("ArrowTable", "ArrowAstropy"),
        ("ArrowAstropy", "ArrowTable"),
        ("ArrowTable", "DataFrame"),
        ("DataFrame", "ArrowTable"),
    )
    return all(getter(a).can_convert(getter(b)) for a, b in pairs)
class PipelineGraphResolveTestCase(unittest.TestCase):
    """More extensive tests for PipelineGraph.resolve and its private helper
    methods.

    These are in a separate TestCase because they utilize a different `setUp`
    from the rest of the `PipelineGraph` tests.
    """

    def setUp(self) -> None:
        # Fresh, empty task configs and a default universe for each test;
        # individual tests add connections before calling make_graph().
        self.a_config = DynamicTestPipelineTaskConfig()
        self.b_config = DynamicTestPipelineTaskConfig()
        self.dimensions = DimensionUniverse()
        # Show full diffs for the large dict comparisons in these tests.
        self.maxDiff = None

    def make_graph(self) -> PipelineGraph:
        """Build a two-task graph from the current ``a``/``b`` configs."""
        graph = PipelineGraph()
        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
        return graph

    def test_prerequisite_inconsistency(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite and another does not.

        This test will hopefully someday go away (along with
        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
        algorithm becomes more flexible.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_inconsistency_reversed(self) -> None:
        """Same as `test_prerequisite_inconsistency`, with the order the edges
        are added to the graph reversed.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_output(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite but another defines it as an output.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_missing(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the dataset type is not registered.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix"}
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_inconsistent(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the rest of the dimensions are
        inconsistent with the registered dataset type.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix", "visit"}
        )
        graph = self.make_graph()
        # Registered dataset type is missing "visit" entirely.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )
        # Registered dataset type has extra dimensions beyond skypix+visit.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )

    def test_duplicate_outputs(self) -> None:
        """Test that we raise an exception when a dataset type node would have
        two write edges.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(DuplicateOutputError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_component_of_unregistered_parent(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is not registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_undefined_component(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but its storage class does not have that
        component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_bad_component_storage_class(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered and has that component, but the storage class
        requested by the connection is incompatible with the component's.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="StructuredDataDict"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
                )
            )

    def test_input_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_output_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an output connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_input_storage_class_incompatible_with_output(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the storage class of the output connection.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_ambiguous_storage_class(self) -> None:
        """Test that we raise an exception when two input connections define
        the same dataset with different storage classes (even compatible ones)
        and there is no output connection or registry definition to take
        precedence.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where input edges have
        different but compatible storage classes and the dataset type is
        already registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        # The registry definition wins; each edge adapts to its own class.
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(
            a_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        self.assertEqual(
            b_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        # generalize_ref round-trips an adapted ref back to the graph's form.
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_output_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where an output edge
        has a different but compatible storage class from the dataset type
        already registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        self.assertEqual(
            a_o.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_output(self) -> None:
        """Test successful resolution of a dataset type where an input edge has
        a different but compatible storage class from the output edge, and
        the dataset type is not registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        graph.resolve(MockRegistry(self.dimensions, {}))
        # With no registry definition, the output edge's storage class wins.
        self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
        self.assertEqual(
            a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type,
        )
        self.assertEqual(
            b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_input(self) -> None:
        """Test successful resolution of a component dataset type due to
        another input referencing the parent dataset type.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_output(self) -> None:
        """Test successful resolution of a component dataset type due to
        an output connection referencing the parent dataset type.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_registry(self) -> None:
        """Test successful resolution of a component dataset type due to
        the parent dataset type already being registered.
        """
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
if __name__ == "__main__":
    # Standard LSST test-module entry point.
    lsst.utils.tests.init()
    unittest.main()