Coverage for tests / test_pipeline_graph.py: 12%
811 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:44 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests of things related to the PipelineGraph class."""
30import copy
31import io
32import itertools
33import logging
34import pickle
35import textwrap
36import unittest
37from typing import Any
39import lsst.pipe.base.automatic_connection_constants as acc
40import lsst.utils.tests
41from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory
42from lsst.daf.butler.registry import MissingDatasetTypeError
43from lsst.pipe.base.pipeline_graph import (
44 ConnectionTypeConsistencyError,
45 DuplicateOutputError,
46 Edge,
47 EdgesChangedError,
48 IncompatibleDatasetTypeError,
49 InvalidExpressionError,
50 InvalidStepsError,
51 NodeKey,
52 NodeType,
53 PipelineGraph,
54 PipelineGraphError,
55 TaskImportMode,
56 UnresolvedGraphError,
57 visualization,
58)
59from lsst.pipe.base.tests.mocks import (
60 DynamicConnectionConfig,
61 DynamicTestPipelineTask,
62 DynamicTestPipelineTaskConfig,
63 get_mock_name,
64)
66_LOG = logging.getLogger(__name__)
class MockRegistry:
    """A test-utility stand-in for lsst.daf.butler.Registry that just knows
    how to get dataset types.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        """Return the dataset type registered under ``name``.

        Raises
        ------
        MissingDatasetTypeError
            Raised if no dataset type with this name is known.
        """
        if name not in self._dataset_types:
            raise MissingDatasetTypeError(name)
        return self._dataset_types[name]
85class PipelineGraphTestCase(unittest.TestCase):
86 """Tests for the `PipelineGraph` class.
88 Tests for `PipelineGraph.resolve` are mostly in
89 `PipelineGraphResolveTestCase` later in this file.
90 """
    def setUp(self) -> None:
        # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
        # 'input', 'intermediate', and 'output'. There are no dimensions on
        # any of those. We add tasks in reverse order to better test sorting.
        # There is one labeled task subset, 'only_b', with just 'b' in it.
        # We copy the configs so the originals (the instance attributes) can
        # be modified and reused after the ones passed in to the graph are
        # frozen.
        self.description = "A pipeline for PipelineGraph unit tests."
        self.graph = PipelineGraph()
        self.graph.description = self.description
        # Task 'b': reads the 'schema' init-input and the 'intermediate_1'
        # dataset produced by 'a'; writes 'output_1'.
        self.b_config = DynamicTestPipelineTaskConfig()
        self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
        self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
        # Task 'a': writes the 'schema' init-output; reads the overall-input
        # 'input_1' and produces 'intermediate_1' for 'b'.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
        self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
        self.graph.add_task_subset("only_b", ["b"])
        self.subset_description = "A subset with only task B in it."
        self.graph.task_subsets["only_b"].description = self.subset_description
        self.dimensions = DimensionUniverse()
        # Show full diffs for the large expected-value comparisons below.
        self.maxDiff = None
119 def test_unresolved_accessors(self) -> None:
120 """Test attribute accessors, iteration, and simple methods on a graph
121 that has not had `PipelineGraph.resolve` called on it.
122 """
123 self.check_base_accessors(self.graph)
124 self.assertEqual(
125 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)"
126 )
127 with self.assertRaises(UnresolvedGraphError):
128 self.graph.packages_dataset_type
129 with self.assertRaises(UnresolvedGraphError):
130 self.graph.instantiate_tasks()
132 def test_sorting(self) -> None:
133 """Test sort methods on PipelineGraph."""
134 self.assertFalse(self.graph.has_been_sorted)
135 self.graph.sort()
136 self.check_sorted(self.graph)
138 def test_unresolved_xgraph_export(self) -> None:
139 """Test exporting an unresolved PipelineGraph to networkx in various
140 ways.
141 """
142 self.check_make_xgraph(self.graph, resolved=False)
143 self.check_make_bipartite_xgraph(self.graph, resolved=False)
144 self.check_make_task_xgraph(self.graph, resolved=False)
145 self.check_make_dataset_type_xgraph(self.graph, resolved=False)
147 def test_unresolved_stream_io(self) -> None:
148 """Test round-tripping an unresolved PipelineGraph through in-memory
149 serialization.
150 """
151 stream = io.BytesIO()
152 self.graph._write_stream(stream)
153 stream.seek(0)
154 roundtripped = PipelineGraph._read_stream(stream)
155 self.check_make_xgraph(roundtripped, resolved=False)
157 def test_unresolved_file_io(self) -> None:
158 """Test round-tripping an unresolved PipelineGraph through file
159 serialization.
160 """
161 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
162 self.graph._write_uri(filename)
163 roundtripped = PipelineGraph._read_uri(filename)
164 self.check_make_xgraph(roundtripped, resolved=False)
166 def test_unresolved_pickle(self) -> None:
167 """Test that unresolved PipelineGraph objects can be pickled."""
168 self.check_make_xgraph(pickle.loads(pickle.dumps(self.graph)), resolved=False)
170 def test_unresolved_deferred_import_io(self) -> None:
171 """Test round-tripping an unresolved PipelineGraph through
172 serialization, without immediately importing tasks on read.
173 """
174 stream = io.BytesIO()
175 self.graph._write_stream(stream)
176 stream.seek(0)
177 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
178 self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
179 self.check_make_xgraph(
180 pickle.loads(pickle.dumps(roundtripped)), resolved=False, imported_and_configured=False
181 )
182 # Check that we can still resolve the graph without importing tasks.
183 roundtripped.resolve(MockRegistry(self.dimensions, {}))
184 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
185 roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
186 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
188 def test_resolved_accessors(self) -> None:
189 """Test attribute accessors, iteration, and simple methods on a graph
190 that has had `PipelineGraph.resolve` called on it.
192 This includes the accessors available on unresolved graphs as well as
193 new ones, and we expect the resolved graph to be sorted as well.
194 """
195 self.graph.resolve(MockRegistry(self.dimensions, {}))
196 self.check_base_accessors(self.graph)
197 self.check_sorted(self.graph)
198 self.assertEqual(
199 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})"
200 )
201 self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty)
202 self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})")
203 self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1"))
204 self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
205 self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
206 self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")
207 self.assertEqual(self.graph.packages_dataset_type.name, acc.PACKAGES_INIT_OUTPUT_NAME)
209 def test_resolved_xgraph_export(self) -> None:
210 """Test exporting a resolved PipelineGraph to networkx in various
211 ways.
212 """
213 self.graph.resolve(MockRegistry(self.dimensions, {}))
214 self.check_make_xgraph(self.graph, resolved=True)
215 self.check_make_bipartite_xgraph(self.graph, resolved=True)
216 self.check_make_task_xgraph(self.graph, resolved=True)
217 self.check_make_dataset_type_xgraph(self.graph, resolved=True)
219 def test_resolved_stream_io(self) -> None:
220 """Test round-tripping a resolved PipelineGraph through in-memory
221 serialization.
222 """
223 # Add some steps to make sure those round-trip, too.
224 self.graph.add_task_subset("step1", {"a"})
225 self.graph.add_task_subset("step2", {"b"})
226 self.graph.steps = ["step1", "step2"]
227 self.graph.resolve(MockRegistry(self.dimensions, {}))
228 stream = io.BytesIO()
229 self.graph._write_stream(stream)
230 stream.seek(0)
231 roundtripped = PipelineGraph._read_stream(stream)
232 self.check_make_xgraph(roundtripped, resolved=True)
233 self.assertEqual(roundtripped.steps, self.graph.steps)
235 def test_resolved_file_io(self) -> None:
236 """Test round-tripping a resolved PipelineGraph through file
237 serialization.
238 """
239 self.graph.resolve(MockRegistry(self.dimensions, {}))
240 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
241 self.graph._write_uri(filename)
242 roundtripped = PipelineGraph._read_uri(filename)
243 self.check_make_xgraph(roundtripped, resolved=True)
245 def test_resolved_pickle(self) -> None:
246 """Test that resolved PipelineGraph objects can be pickled."""
247 self.graph.resolve(MockRegistry(self.dimensions, {}))
248 self.check_make_xgraph(pickle.loads(pickle.dumps(self.graph)), resolved=True)
250 def test_resolved_deferred_import_io(self) -> None:
251 """Test round-tripping a resolved PipelineGraph through serialization,
252 without immediately importing tasks on read.
253 """
254 self.graph.resolve(MockRegistry(self.dimensions, {}))
255 stream = io.BytesIO()
256 self.graph._write_stream(stream)
257 stream.seek(0)
258 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
259 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
260 self.check_make_xgraph(
261 pickle.loads(pickle.dumps(roundtripped)), resolved=True, imported_and_configured=False
262 )
263 roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
264 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
266 def test_unresolved_copies(self) -> None:
267 """Test making copies of an unresolved PipelineGraph."""
268 copy1 = self.graph.copy()
269 self.assertIsNot(copy1, self.graph)
270 self.check_make_xgraph(copy1, resolved=False)
271 copy2 = copy.copy(self.graph)
272 self.assertIsNot(copy2, self.graph)
273 self.check_make_xgraph(copy2, resolved=False)
274 copy3 = copy.deepcopy(self.graph)
275 self.assertIsNot(copy3, self.graph)
276 self.check_make_xgraph(copy3, resolved=False)
278 def test_resolved_copies(self) -> None:
279 """Test making copies of a resolved PipelineGraph."""
280 self.graph.resolve(MockRegistry(self.dimensions, {}))
281 copy1 = self.graph.copy()
282 self.assertIsNot(copy1, self.graph)
283 self.check_make_xgraph(copy1, resolved=True)
284 copy2 = copy.copy(self.graph)
285 self.assertIsNot(copy2, self.graph)
286 self.check_make_xgraph(copy2, resolved=True)
287 copy3 = copy.deepcopy(self.graph)
288 self.assertIsNot(copy3, self.graph)
289 self.check_make_xgraph(copy3, resolved=True)
291 def test_valid_steps(self) -> None:
292 """Test step definitions that are valid."""
293 self.graph.add_task_subset("step1", {"a"})
294 self.graph.add_task_subset("step2", {"b"})
295 with self.assertRaises(InvalidStepsError):
296 # Can't call this yet; no steps.
297 self.graph.get_task_step("step1")
298 with self.assertRaises(InvalidStepsError):
299 # Can't call this yet either.
300 self.graph.steps.get_dimensions("step1")
301 self.graph.steps = ["step1", "step2"]
302 with self.assertRaises(UnresolvedGraphError):
303 # Still can't call it yet; steps not verified.
304 self.graph.get_task_step("step1")
305 with self.assertRaises(UnresolvedGraphError):
306 # Can't call this yet either.
307 self.graph.steps.get_dimensions("step1")
308 self.assertEqual(list(self.graph.steps), ["step1", "step2"])
309 self.graph.resolve(MockRegistry(self.dimensions, {}))
310 self.assertEqual(str(self.graph.steps), "['step1', 'step2']")
311 self.assertEqual(list(self.graph.steps), ["step1", "step2"])
312 self.assertTrue(self.graph.task_subsets["step1"].is_step)
313 self.assertTrue(self.graph.task_subsets["step2"].is_step)
314 self.assertEqual(self.graph.task_subsets["step1"].dimensions, self.dimensions.empty)
315 self.assertEqual(self.graph.task_subsets["step2"].dimensions, self.dimensions.empty)
316 self.assertEqual(self.graph.get_task_step("a"), "step1")
317 self.assertEqual(self.graph.get_task_step("b"), "step2")
318 # Check that step verification round-trips through serialization.
319 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
320 self.graph._write_uri(filename)
321 roundtripped = PipelineGraph._read_uri(filename)
322 self.assertTrue(roundtripped.is_fully_resolved)
324 def test_valid_steps_resolved_graph(self) -> None:
325 """Test step definitions that are valid, adding them to a graph that
326 has already been resolved.
327 """
328 self.graph.add_task_subset("step1", {"a"})
329 self.graph.add_task_subset("step2", {"b"})
330 self.graph.resolve(MockRegistry(self.dimensions, {}))
331 # Can't call these yet; no steps.
332 with self.assertRaises(InvalidStepsError):
333 self.graph.get_task_step("a")
334 self.graph.steps = ["step1"]
335 self.graph.steps.append("step2", dimensions=self.dimensions.empty)
336 with self.assertRaises(UnresolvedGraphError):
337 # Still can't call it yet; steps not verified.
338 self.graph.get_task_step("a")
339 self.graph.resolve(MockRegistry(self.dimensions, {}))
340 # After we resolve again everything works.
341 self.assertEqual(list(self.graph.steps), ["step1", "step2"])
342 self.assertEqual(list(self.graph.steps), ["step1", "step2"])
343 self.assertTrue(self.graph.task_subsets["step1"].is_step)
344 self.assertTrue(self.graph.task_subsets["step2"].is_step)
345 self.assertEqual(self.graph.task_subsets["step1"].dimensions, self.dimensions.empty)
346 self.assertEqual(self.graph.task_subsets["step2"].dimensions, self.dimensions.empty)
347 self.assertEqual(self.graph.get_task_step("a"), "step1")
348 self.assertEqual(self.graph.get_task_step("b"), "step2")
350 def test_valid_step_exposure_visit_substitution(self) -> None:
351 """Test that step sharding dimensions permit an 'exposure'-based task
352 in a 'visit'-sharded step.
353 """
354 c_config = DynamicTestPipelineTaskConfig()
355 c_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
356 c_config.outputs["output2"] = DynamicConnectionConfig(
357 dataset_type_name="output_2", dimensions=["exposure", "detector"]
358 )
359 c_config.dimensions = ["exposure"]
360 self.graph.add_task("c", DynamicTestPipelineTask, c_config)
361 self.graph.add_task_subset("step1", {"a", "b"})
362 self.graph.add_task_subset("step2", {"c"})
363 self.graph.steps = ["step1", "step2"]
364 self.graph.task_subsets["step2"].dimensions = {"visit"}
365 self.graph.resolve(MockRegistry(self.dimensions, {}))
366 self.assertEqual(list(self.graph.steps), ["step1", "step2"])
367 self.assertEqual(self.graph.get_task_step("a"), "step1")
368 self.assertEqual(self.graph.get_task_step("b"), "step1")
369 self.assertEqual(self.graph.get_task_step("c"), "step2")
371 def test_reset_steps(self) -> None:
372 """Test that assigning steps from one graph to another transfers the
373 sharding dimensions.
374 """
375 new_graph = PipelineGraph()
376 new_graph.add_task_nodes(self.graph.tasks.values(), parent=self.graph)
377 self.graph.add_task_subset("step1", {"a"})
378 self.graph.add_task_subset("step2", {"b"})
379 self.graph.steps = ["step1"]
380 # These dimensions are not valid for the task dimensions we have, but
381 # that shouldn't be a problem until we try to resolve them.
382 self.graph.steps.append("step2", dimensions=self.dimensions.conform(["visit"]))
383 new_graph.steps = self.graph.steps
384 with self.assertRaises(InvalidStepsError):
385 new_graph.resolve(MockRegistry(self.dimensions, {}))
387 def test_invalid_steps_repeated_task(self) -> None:
388 """Test step definitions that are invalid because a task appears in
389 more than one step.
390 """
391 self.graph.add_task_subset("step1", {"a"})
392 self.graph.add_task_subset("step2", {"a", "b"})
393 self.graph.steps = ["step1", "step2"]
394 with self.assertRaises(InvalidStepsError):
395 self.graph.resolve(MockRegistry(self.dimensions, {}))
397 def test_invalid_steps_missing_task(self) -> None:
398 """Test step definitions that are invalid because a task appears in
399 more than one step.
400 """
401 self.graph.add_task_subset("step1", {"a"})
402 self.graph.steps = ["step1"]
403 with self.assertRaises(InvalidStepsError):
404 self.graph.resolve(MockRegistry(self.dimensions, {}))
406 def test_invalid_steps_bad_order(self) -> None:
407 """Test step definitions that are invalid because they are inconsistent
408 with the task flow.
409 """
410 self.graph.add_task_subset("step1", {"b"})
411 self.graph.add_task_subset("step2", {"a"})
412 self.graph.steps = ["step1", "step2"]
413 with self.assertRaises(InvalidStepsError):
414 self.graph.resolve(MockRegistry(self.dimensions, {}))
416 def test_invalid_steps_not_a_subset(self) -> None:
417 """Test step definitions that are invalid because they reference a
418 label that is not a task subset.
419 """
420 self.graph.add_task_subset("step1", {"b"})
421 self.graph.add_task_subset("step2", {"a"})
422 self.graph.steps = ["step1", "step2", "step3"]
423 with self.assertRaises(PipelineGraphError):
424 self.graph.steps.get_dimensions("step3")
425 with self.assertRaises(InvalidStepsError):
426 self.graph.resolve(MockRegistry(self.dimensions, {}))
428 def test_invalid_steps_bad_task_dimensions(self) -> None:
429 """Test step definitions that are invalid because the step dimensions
430 (for sharding) are incompatible with task dimensions.
431 """
432 # Resolve up-front so methods below have a DimensionUniverse; this
433 # triggers additional code to check dimensions as early as possible.
434 self.graph.resolve(MockRegistry(self.dimensions, {}))
435 self.graph.add_task_subset("step1", {"a"})
436 self.graph.add_task_subset("step2", {"b"})
437 with self.assertRaises(PipelineGraphError):
438 # Only steps can have dimensions, and this isn't a step yet.
439 self.graph.steps.set_dimensions("step2", {"visit"})
440 self.graph.steps = ["step1", "step2"]
441 self.graph.task_subsets["step2"].dimensions = {"visit"}
442 with self.assertRaises(InvalidStepsError):
443 self.graph.resolve(MockRegistry(self.dimensions, {}))
445 def test_invalid_steps_bad_dataset_type_dimensions(self) -> None:
446 """Test step definitions that are invalid because the step dimensions
447 (for sharding) are incompatible with output dataset type dimensions.
448 """
449 # This task includes an output that does not have all of the dimensions
450 # of the task itself, which is probably a malformed task and may be
451 # banned earlier in the future. At present it is not banned, so the
452 # step validation needs to check for consistency on these output
453 # dataset types as well, and we need to test that.
454 c_config = DynamicTestPipelineTaskConfig()
455 c_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
456 c_config.outputs["output2"] = DynamicConnectionConfig(dataset_type_name="output_2")
457 c_config.dimensions = ["detector"]
458 self.graph.add_task("c", DynamicTestPipelineTask, c_config)
459 self.graph.add_task_subset("step1", {"a", "b"})
460 self.graph.add_task_subset("step2", {"c"})
461 self.graph.steps = ["step1", "step2"]
462 self.graph.task_subsets["step2"].dimensions = {"detector"}
463 with self.assertRaises(InvalidStepsError):
464 self.graph.resolve(MockRegistry(self.dimensions, {}))
466 def check_base_accessors(self, graph: PipelineGraph) -> None:
467 """Run parameterized tests that check attribute access, iteration, and
468 simple methods.
470 The given graph must be unchanged from the one defined in `setUp`,
471 other than sorting.
472 """
473 self.assertEqual(graph.description, self.description)
474 self.assertEqual(graph.tasks.keys(), {"a", "b"})
475 self.assertEqual(
476 graph.dataset_types.keys(),
477 {
478 "schema",
479 "input_1",
480 "intermediate_1",
481 "output_1",
482 "a_config",
483 "a_log",
484 "a_metadata",
485 "b_config",
486 "b_log",
487 "b_metadata",
488 },
489 )
490 self.assertEqual(graph.task_subsets.keys(), {"only_b"})
491 self.assertEqual(
492 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)},
493 {
494 (
495 NodeKey(NodeType.DATASET_TYPE, "input_1"),
496 NodeKey(NodeType.TASK, "a"),
497 "input_1 -> a (input1)",
498 ),
499 (
500 NodeKey(NodeType.TASK, "a"),
501 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
502 "a -> intermediate_1 (output1)",
503 ),
504 (
505 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
506 NodeKey(NodeType.TASK, "b"),
507 "intermediate_1 -> b (input1)",
508 ),
509 (
510 NodeKey(NodeType.TASK, "b"),
511 NodeKey(NodeType.DATASET_TYPE, "output_1"),
512 "b -> output_1 (output1)",
513 ),
514 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"),
515 (
516 NodeKey(NodeType.TASK, "a"),
517 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
518 "a -> a_metadata (_metadata)",
519 ),
520 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"),
521 (
522 NodeKey(NodeType.TASK, "b"),
523 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
524 "b -> b_metadata (_metadata)",
525 ),
526 },
527 )
528 self.assertEqual(
529 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)},
530 {
531 (
532 NodeKey(NodeType.TASK_INIT, "a"),
533 NodeKey(NodeType.DATASET_TYPE, "schema"),
534 "a -> schema (out_schema)",
535 ),
536 (
537 NodeKey(NodeType.DATASET_TYPE, "schema"),
538 NodeKey(NodeType.TASK_INIT, "b"),
539 "schema -> b (in_schema)",
540 ),
541 (
542 NodeKey(NodeType.TASK_INIT, "a"),
543 NodeKey(NodeType.DATASET_TYPE, "a_config"),
544 "a -> a_config (_config)",
545 ),
546 (
547 NodeKey(NodeType.TASK_INIT, "b"),
548 NodeKey(NodeType.DATASET_TYPE, "b_config"),
549 "b -> b_config (_config)",
550 ),
551 },
552 )
553 self.assertEqual(
554 {(node_type, name) for node_type, name, _ in graph.iter_nodes()},
555 {
556 NodeKey(NodeType.TASK, "a"),
557 NodeKey(NodeType.TASK, "b"),
558 NodeKey(NodeType.TASK_INIT, "a"),
559 NodeKey(NodeType.TASK_INIT, "b"),
560 NodeKey(NodeType.DATASET_TYPE, "schema"),
561 NodeKey(NodeType.DATASET_TYPE, "input_1"),
562 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
563 NodeKey(NodeType.DATASET_TYPE, "output_1"),
564 NodeKey(NodeType.DATASET_TYPE, "a_config"),
565 NodeKey(NodeType.DATASET_TYPE, "a_log"),
566 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
567 NodeKey(NodeType.DATASET_TYPE, "b_config"),
568 NodeKey(NodeType.DATASET_TYPE, "b_log"),
569 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
570 },
571 )
572 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"})
573 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"})
574 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"})
575 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set())
576 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"})
577 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"})
578 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set())
580 self.assertIsNone(graph.producing_edge_of("input_1"))
581 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a")
582 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b")
583 self.assertIsNone(graph.producer_of("input_1"))
584 self.assertEqual(graph.producer_of("intermediate_1").label, "a")
585 self.assertEqual(graph.producer_of("output_1").label, "b")
587 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"})
588 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"})
589 self.assertEqual(graph.inputs_of("a", init=True).keys(), set())
590 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"})
591 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"})
592 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"})
593 self.assertEqual(
594 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"}
595 )
596 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"})
597 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"})
598 self.assertEqual(
599 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"}
600 )
601 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"})
602 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set())
604 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
605 self.assertEqual(
606 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}"
607 )
608 for task_node in graph.tasks.values():
609 for edge in itertools.chain(task_node.iter_all_inputs(), task_node.iter_all_outputs()):
610 self.assertEqual(task_node.get_edge(edge.connection_name), edge)
611 for edge in itertools.chain(task_node.init.iter_all_inputs(), task_node.init.iter_all_outputs()):
612 self.assertEqual(task_node.init.get_edge(edge.connection_name), edge)
614 def check_sorted(self, graph: PipelineGraph) -> None:
615 """Run a battery of tests on a PipelineGraph that must be
616 deterministically sorted.
618 The given graph must be unchanged from the one defined in `setUp`,
619 other than sorting.
620 """
621 self.assertTrue(graph.has_been_sorted)
622 self.assertEqual(
623 [(node_type, name) for node_type, name, _ in graph.iter_nodes()],
624 [
625 # We only advertise that the order is topological and
626 # deterministic, so this test is slightly over-specified; there
627 # are other orders that are consistent with our guarantees.
628 NodeKey(NodeType.DATASET_TYPE, "input_1"),
629 NodeKey(NodeType.TASK_INIT, "a"),
630 NodeKey(NodeType.DATASET_TYPE, "a_config"),
631 NodeKey(NodeType.DATASET_TYPE, "schema"),
632 NodeKey(NodeType.TASK_INIT, "b"),
633 NodeKey(NodeType.DATASET_TYPE, "b_config"),
634 NodeKey(NodeType.TASK, "a"),
635 NodeKey(NodeType.DATASET_TYPE, "a_log"),
636 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
637 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
638 NodeKey(NodeType.TASK, "b"),
639 NodeKey(NodeType.DATASET_TYPE, "b_log"),
640 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
641 NodeKey(NodeType.DATASET_TYPE, "output_1"),
642 ],
643 )
644 # Most users should only care that the tasks and dataset types are
645 # topologically sorted.
646 self.assertEqual(list(graph.tasks), ["a", "b"])
647 self.assertEqual(
648 list(graph.dataset_types),
649 [
650 "input_1",
651 "a_config",
652 "schema",
653 "b_config",
654 "a_log",
655 "a_metadata",
656 "intermediate_1",
657 "b_log",
658 "b_metadata",
659 "output_1",
660 ],
661 )
662 # __str__ and __repr__ of course work on unsorted mapping views, too,
663 # but the order of elements is then nondeterministic and hard to test.
664 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})")
665 self.assertEqual(
666 repr(self.graph.dataset_types),
667 (
668 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, "
669 "intermediate_1, b_log, b_metadata, output_1})"
670 ),
671 )
    def check_make_xgraph(
        self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
    ) -> None:
        """Check that the given graph exports as expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``) or round-tripped
        through serialization without tasks being imported (if
        ``imported_and_configured=False``).
        """
        xgraph = graph.make_xgraph()
        # The full export contains the runtime edges, the init edges, and one
        # synthetic edge from each task-init node to its task node.
        expected_edges = (
            {edge.key for edge in graph.iter_edges()}
            | {edge.key for edge in graph.iter_edges(init=True)}
            | {
                (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
                (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
            }
        )
        test_edges = set(xgraph.edges)
        self.assertEqual(test_edges, expected_edges)
        # Build the expected node-attribute mapping from the get_expected_*
        # helper methods, which vary with the resolved/imported state.
        expected_nodes = {
            NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                "schema", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                "input_1", resolved, is_initial_query_constraint=True
            ),
            NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                "intermediate_1", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                "output_1", resolved, is_initial_query_constraint=False
            ),
        }
        test_nodes = dict(xgraph.nodes.items())
        # Compare key sets first for a readable failure, then each node's
        # attributes individually with the key as the failure message.
        self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
        for key, expected_node in expected_nodes.items():
            test_node = test_nodes[key]
            self.assertEqual(expected_node, test_node, key)
    def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's init-only or runtime subset exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime subgraph: edge keys must match the graph's own edge
        # iteration exactly.
        run_xgraph = graph.make_bipartite_xgraph()
        self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
        # Node attributes: both task nodes, their per-task log/metadata
        # dataset types, and the three connection dataset types.
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init subgraph: task-init nodes plus config init-outputs and the
        # "schema" init dataset type.
        init_xgraph = graph.make_bipartite_xgraph(
            init=True,
        )
        self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
    def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's task-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: a single a -> b edge and only the task nodes.
        run_xgraph = graph.make_task_xgraph()
        self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
            },
        )
        # Init projection: same shape, between the task-init nodes.
        init_xgraph = graph.make_task_xgraph(
            init=True,
        )
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
            },
        )
    def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's dataset-type-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: every dataset type task "a" reads is connected
        # to every dataset type task "a" writes (including its log/metadata),
        # and likewise for task "b".
        run_xgraph = graph.make_dataset_type_xgraph()
        self.assertEqual(
            set(run_xgraph.edges),
            {
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "output_1"),
                ),
                (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
                ),
            },
        )
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init projection: "schema" feeds b's config init-output only; a's
        # config has no init-inputs upstream of it.
        init_xgraph = graph.make_dataset_type_xgraph(init=True)
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
868 def get_expected_task_node(
869 self, label: str, resolved: bool, imported_and_configured: bool = True
870 ) -> dict[str, Any]:
871 """Construct a networkx-export task node for comparison."""
872 result = self.get_expected_task_init_node(
873 label, resolved, imported_and_configured=imported_and_configured
874 )
875 if resolved:
876 result["dimensions"] = self.dimensions.empty
877 result["raw_dimensions"] = frozenset()
878 return result
880 def get_expected_task_init_node(
881 self, label: str, resolved: bool, imported_and_configured: bool = True
882 ) -> dict[str, Any]:
883 """Construct a networkx-export task init for comparison."""
884 result = {
885 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask",
886 "bipartite": 1,
887 }
888 if imported_and_configured:
889 result["task_class"] = DynamicTestPipelineTask
890 result["config"] = getattr(self, f"{label}_config")
891 return result
893 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]:
894 """Construct a networkx-export init-output config dataset type node for
895 comparison.
896 """
897 if not resolved:
898 return {"bipartite": 0}
899 else:
900 return {
901 "dataset_type": DatasetType(
902 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
903 self.dimensions.empty,
904 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
905 ),
906 "is_initial_query_constraint": False,
907 "is_prerequisite": False,
908 "dimensions": self.dimensions.empty,
909 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
910 "bipartite": 0,
911 }
913 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]:
914 """Construct a networkx-export output log dataset type node for
915 comparison.
916 """
917 if not resolved:
918 return {"bipartite": 0}
919 else:
920 return {
921 "dataset_type": DatasetType(
922 acc.LOG_OUTPUT_TEMPLATE.format(label=label),
923 self.dimensions.empty,
924 acc.LOG_OUTPUT_STORAGE_CLASS,
925 ),
926 "is_initial_query_constraint": False,
927 "is_prerequisite": False,
928 "dimensions": self.dimensions.empty,
929 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS,
930 "bipartite": 0,
931 }
933 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]:
934 """Construct a networkx-export output metadata dataset type node for
935 comparison.
936 """
937 if not resolved:
938 return {"bipartite": 0}
939 else:
940 return {
941 "dataset_type": DatasetType(
942 acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
943 self.dimensions.empty,
944 acc.METADATA_OUTPUT_STORAGE_CLASS,
945 ),
946 "is_initial_query_constraint": False,
947 "is_prerequisite": False,
948 "dimensions": self.dimensions.empty,
949 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS,
950 "bipartite": 0,
951 }
953 def get_expected_connection_node(
954 self, name: str, resolved: bool, *, is_initial_query_constraint: bool
955 ) -> dict[str, Any]:
956 """Construct a networkx-export dataset type node for comparison."""
957 if not resolved:
958 return {"bipartite": 0}
959 else:
960 return {
961 "dataset_type": DatasetType(
962 name,
963 self.dimensions.empty,
964 get_mock_name("StructuredDataDict"),
965 ),
966 "is_initial_query_constraint": is_initial_query_constraint,
967 "is_prerequisite": False,
968 "dimensions": self.dimensions.empty,
969 "storage_class_name": get_mock_name("StructuredDataDict"),
970 "bipartite": 0,
971 }
    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with DimensionUniverse, all tasks added to
        it should have resolved dimensions, but not (yet) resolved dataset
        types. We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        # Dimensions resolve immediately because the universe is known.
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions.
        self.graph.add_task_nodes([graph.tasks["b1"]])
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)
    def test_group_by_dimensions(self) -> None:
        """Test PipelineGraph.group_by_dimensions."""
        # Grouping requires a resolved graph.
        with self.assertRaises(UnresolvedGraphError):
            self.graph.group_by_dimensions()
        # Give task "a" visit dimensions and an htm7 prerequisite, and task
        # "b" htm7 dimensions, so the graph spans three dimension groups.
        self.a_config.dimensions = ["visit"]
        self.a_config.outputs["output1"].dimensions = ["visit"]
        self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
            dataset_type_name="prereq_1",
            multiple=True,
            dimensions=["htm7"],
            is_calibration=True,
        )
        self.b_config.dimensions = ["htm7"]
        self.b_config.inputs["input1"].dimensions = ["visit"]
        self.b_config.inputs["input1"].multiple = True
        self.b_config.outputs["output1"].dimensions = ["htm7"]
        self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        visit_dims = self.dimensions.conform(["visit"])
        htm7_dims = self.dimensions.conform(["htm7"])
        # Expected grouping: each key maps to (tasks, dataset types) with
        # exactly those dimensions.
        expected = {
            self.dimensions.empty: (
                {},
                {
                    "schema": self.graph.dataset_types["schema"],
                    "input_1": self.graph.dataset_types["input_1"],
                    "a_config": self.graph.dataset_types["a_config"],
                    "b_config": self.graph.dataset_types["b_config"],
                },
            ),
            visit_dims: (
                {"a": self.graph.tasks["a"]},
                {
                    "a_log": self.graph.dataset_types["a_log"],
                    "a_metadata": self.graph.dataset_types["a_metadata"],
                    "intermediate_1": self.graph.dataset_types["intermediate_1"],
                },
            ),
            htm7_dims: (
                {"b": self.graph.tasks["b"]},
                {
                    "b_log": self.graph.dataset_types["b_log"],
                    "b_metadata": self.graph.dataset_types["b_metadata"],
                    "output_1": self.graph.dataset_types["output_1"],
                },
            ),
        }
        self.assertEqual(self.graph.group_by_dimensions(), expected)
        # Prerequisites are only included when explicitly requested.
        expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
        self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)
        self.assertEqual(self.graph.get_all_dimensions(), visit_dims | htm7_dims)
1052 def test_add_and_remove(self) -> None:
1053 """Tests for adding and removing tasks and task subsets from a
1054 PipelineGraph.
1055 """
1056 original = self.graph.copy()
1057 # Can't remove a task while it's still in a subset.
1058 with self.assertRaises(PipelineGraphError):
1059 self.graph.remove_tasks(["b"], drop_from_subsets=False)
1060 self.assertEqual(original.diff_tasks(self.graph), [])
1061 # ...unless you remove the subset.
1062 self.graph.remove_task_subset("only_b")
1063 self.assertFalse(self.graph.task_subsets)
1064 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False)
1065 self.assertFalse(referencing_subsets)
1066 self.assertEqual(self.graph.tasks.keys(), {"a"})
1067 self.assertEqual(
1068 original.diff_tasks(self.graph),
1069 ["Pipelines have different tasks: A & ~B = ['b'], B & ~A = []."],
1070 )
1071 # Add that task back in.
1072 self.graph.add_task_nodes([b])
1073 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
1074 # Add the subset back in.
1075 self.graph.add_task_subset("only_b", {"b"})
1076 self.assertEqual(self.graph.task_subsets.keys(), {"only_b"})
1077 # Add a task to the subset and then remove it.
1078 self.graph.task_subsets["only_b"].add("a")
1079 self.assertEqual(self.graph.task_subsets["only_b"], {"a", "b"})
1080 self.assertEqual(self.graph.task_subsets["only_b"] & {"b"}, {"b"})
1081 self.graph.task_subsets["only_b"].remove("a")
1082 self.assertEqual(self.graph.task_subsets["only_b"], {"b"})
1083 with self.assertRaises(PipelineGraphError):
1084 self.graph.task_subsets["only_b"].add("c")
1085 # Resolve the graph's dataset types and task dimensions.
1086 self.graph.resolve(MockRegistry(self.dimensions, {}))
1087 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
1088 self.assertTrue(self.graph.dataset_types.is_resolved("output_1"))
1089 self.assertTrue(self.graph.dataset_types.is_resolved("schema"))
1090 self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1"))
1091 # Remove the task while removing it from the subset automatically. This
1092 # should also unresolve (only) the referenced dataset types and drop
1093 # any datasets no longer attached to any task.
1094 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
1095 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
1096 self.assertEqual(referencing_subsets, {"only_b"})
1097 self.assertEqual(self.graph.tasks.keys(), {"a"})
1098 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
1099 self.assertNotIn("output1", self.graph.dataset_types)
1100 self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
1101 self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        original = self.graph.copy()
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Change b's output storage class; this changes an edge, which the
        # various reconfigure modes below must handle or reject.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        self.assertEqual(original.diff_tasks(self.graph), [])
        # Make a change that does affect edges; this will unresolve most
        # dataset types.
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        self.assertEqual(
            original.diff_tasks(self.graph),
            [
                "Output b.output1 has storage class '_mock_StructuredDataDict' in A, "
                "but '_mock_TaskMetadata' in B."
            ],
        )
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )
1139 def check_visualization(self, graph: PipelineGraph, expected: str, **kwargs: Any) -> None:
1140 """Run pipeline graph visualization with the given kwargs and check
1141 that the output is the given expected string.
1143 Parameters
1144 ----------
1145 graph : `lsst.pipe.base.pipeline_graph.PipelineGraph`
1146 Pipeline graph to visualize.
1147 expected : `str`
1148 Expected output of the visualization. Will be passed through
1149 `textwrap.dedent`, to allow it to be written with triple-quotes.
1150 **kwargs
1151 Forwarded to `lsst.pipe.base.pipeline_graph.visualization.show`.
1152 """
1153 stream = io.StringIO()
1154 visualization.show(graph, stream, **kwargs)
1155 self.assertEqual(textwrap.dedent(expected), stream.getvalue())
    def test_unresolved_visualization(self) -> None:
        """Test pipeline graph text-based visualization on unresolved
        graphs.
        """
        # Tasks only, with all dataset-type merging disabled.
        self.check_visualization(
            self.graph,
            """
            ■ a
            │
            ■ b
            """,
            merge_input_trees=0,
            merge_output_trees=0,
            merge_intermediates=False,
        )
        # Include dataset type nodes in the rendering as well.
        self.check_visualization(
            self.graph,
            """
            ○ input_1
            │
            ■ a
            │
            ○ intermediate_1
            │
            ■ b
            │
            ○ output_1
            """,
            dataset_types=True,
        )
    def test_resolved_visualization(self) -> None:
        """Test pipeline graph text-based visualization on resolved graphs."""
        self.graph.resolve(MockRegistry(dimensions=self.dimensions, dataset_types={}))
        # Concise task classes: unqualified class names, no dataset types.
        self.check_visualization(
            self.graph,
            """
            ■ a: {} DynamicTestPipelineTask
            │
            ■ b: {} DynamicTestPipelineTask
            """,
            task_classes="concise",
            merge_input_trees=0,
            merge_output_trees=0,
            merge_intermediates=False,
        )
        # Full task classes and dataset types, with storage class names.
        self.check_visualization(
            self.graph,
            """
            ○ input_1: {} _mock_StructuredDataDict
            │
            ■ a: {} lsst.pipe.base.tests.mocks.DynamicTestPipelineTask
            │
            ○ intermediate_1: {} _mock_StructuredDataDict
            │
            ■ b: {} lsst.pipe.base.tests.mocks.DynamicTestPipelineTask
            │
            ○ output_1: {} _mock_StructuredDataDict
            """,
            task_classes="full",
            dataset_types=True,
        )
    def test_select(self) -> None:
        """Test PipelineGraph.select_tasks."""
        # New task c is downstream of a and parallel (unrelated) to b.
        c_config = DynamicTestPipelineTaskConfig()
        c_config.inputs["input3"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        c_config.outputs["output3"] = DynamicConnectionConfig(dataset_type_name="intermediate_2")
        self.graph.add_task("c", DynamicTestPipelineTask, c_config)
        # New task d is downstream of c and parallel (unrelated) to b.
        d_config = DynamicTestPipelineTaskConfig()
        d_config.inputs["input4"] = DynamicConnectionConfig(dataset_type_name="intermediate_2")
        d_config.outputs["output4"] = DynamicConnectionConfig(dataset_type_name="output_2")
        self.graph.add_task("d", DynamicTestPipelineTask, d_config)
        self.graph.add_task_subset("c_and_d", {"c", "d"})
        # Check identifiers: bare and prefixed task (T:), subset (S:), and
        # dataset type (D:) names; a dataset type selects its producing task.
        self.check_expression("a", {"a"})
        self.check_expression("T:a", {"a"})
        self.check_expression("only_b", {"b"})
        self.check_expression("S:c_and_d", {"c", "d"})
        self.check_expression("output_1", {"b"})
        self.check_expression("input_1", set())
        self.check_expression("intermediate_1", {"a"})
        self.check_expression("D:intermediate_1", {"a"})
        # Check inversion, intersection, and union.
        self.check_expression("~a", {"b", "c", "d"})
        self.check_expression("~only_b", {"a", "c", "d"})
        self.check_expression("a | c_and_d", {"a", "c", "d"})
        self.check_expression("c & c_and_d", {"c"})
        self.check_expression("(a | b) & (b | c)", {"b"})
        # Check ancestors and descendants for task operands.
        self.check_expression("<a", set())
        self.check_expression("<b", {"a"})
        self.check_expression("<c", {"a"})
        self.check_expression("<d", {"a", "c"})
        self.check_expression("<=a", {"a"})
        self.check_expression("<=b", {"a", "b"})
        self.check_expression("<=c", {"a", "c"})
        self.check_expression("<=d", {"a", "c", "d"})
        self.check_expression(">a", {"b", "c", "d"})
        self.check_expression(">b", set())
        self.check_expression(">c", {"d"})
        self.check_expression(">d", set())
        self.check_expression(">=a", {"a", "b", "c", "d"})
        self.check_expression(">=b", {"b"})
        self.check_expression(">=c", {"c", "d"})
        self.check_expression(">=d", {"d"})
        # Check ancestors and descendants for dataset type operands.
        self.check_expression("<input_1", set())
        self.check_expression("<intermediate_1", set())
        self.check_expression("<intermediate_2", {"a"})
        self.check_expression("<output_1", {"a"})
        self.check_expression("<output_2", {"a", "c"})
        self.check_expression("<=input_1", set())
        self.check_expression("<=intermediate_1", {"a"})
        self.check_expression("<=intermediate_2", {"a", "c"})
        self.check_expression("<=output_1", {"a", "b"})
        self.check_expression("<=output_2", {"a", "c", "d"})
        self.check_expression(">input_1", {"a", "b", "c", "d"})
        self.check_expression(">intermediate_1", {"b", "c", "d"})
        self.check_expression(">intermediate_2", {"d"})
        self.check_expression(">output_1", set())
        self.check_expression(">output_2", set())
        self.check_expression(">=input_1", {"a", "b", "c", "d"})
        self.check_expression(">=intermediate_1", {"a", "b", "c", "d"})
        self.check_expression(">=intermediate_2", {"c", "d"})
        self.check_expression(">=output_1", {"b"})
        self.check_expression(">=output_2", {"d"})
        # Test that init datasets can be used with non-inclusive descendant
        # searches (i.e. as long as we're not trying to select the
        # pre-exec-init process as a task, it's fine).
        self.check_expression(">schema", {"b"})
        # Check exceptions for invalid expressions.
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("schema")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("<schema")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("<=schema")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks(">=schema")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("<only_b")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks(">c_and_d")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("q")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("T:q")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("D:q")
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("S:q")
        # Add task and subset with names that duplicate a dataset type, check
        # name resolution (note that subset/task conflicts are not allowed).
        # New task d is downstream of c and parallel (unrelated) to b.
        e_config = DynamicTestPipelineTaskConfig()
        e_config.inputs["input5"] = DynamicConnectionConfig(dataset_type_name="intermediate_2")
        e_config.outputs["output5"] = DynamicConnectionConfig(dataset_type_name="e")
        e_config.outputs["output6"] = DynamicConnectionConfig(dataset_type_name="f")
        self.graph.add_task("e", DynamicTestPipelineTask, e_config)
        self.graph.add_task_subset("f", {"a"})
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("e")  # ambiguous
        with self.assertRaises(InvalidExpressionError):
            self.graph.select_tasks("f")  # ambiguous
        self.check_expression("T:e", {"e"})
        self.check_expression("D:e", {"e"})
        self.check_expression("S:f", {"a"})
        self.check_expression("D:f", {"e"})
1331 def check_expression(self, expression: str, expectation: set[str]) -> None:
1332 """Test that `PipelineGraph.select` and `PipelineGraph.select_tasks`
1333 yield the given result for the given expression.
1334 """
1335 self.assertEqual(self.graph.select_tasks(expression), expectation)
1336 self.assertEqual(self.graph.select(expression).tasks.keys(), expectation)
def _have_example_storage_classes() -> bool:
    """Check whether some storage classes work as expected.

    Given that these have registered converters, it shouldn't actually be
    necessary to import those types in order to determine that they're
    convertible, but the storage class machinery is implemented such that types
    that can't be imported can't be converted, and while that's inconvenient
    here it's totally fine in non-testing scenarios where you only care about a
    storage class if you can actually use it.
    """
    get = StorageClassFactory().getStorageClass
    # Conversions must work in both directions for each pair.
    required_conversions = [
        ("ArrowTable", "ArrowAstropy"),
        ("ArrowAstropy", "ArrowTable"),
        ("ArrowTable", "DataFrame"),
        ("DataFrame", "ArrowTable"),
    ]
    return all(get(target).can_convert(get(source)) for target, source in required_conversions)
1358class PipelineGraphResolveTestCase(unittest.TestCase):
1359 """More extensive tests for PipelineGraph.resolve and its primary helper
1360 methods.
1362 These are in a separate TestCase because they utilize a different `setUp`
1363 from the rest of the `PipelineGraph` tests.
1364 """
1366 def setUp(self) -> None:
1367 self.a_config = DynamicTestPipelineTaskConfig()
1368 self.b_config = DynamicTestPipelineTaskConfig()
1369 self.dimensions = DimensionUniverse()
1370 self.maxDiff = None
1372 def make_graph(self) -> PipelineGraph:
1373 graph = PipelineGraph()
1374 graph.add_task("a", DynamicTestPipelineTask, self.a_config)
1375 graph.add_task("b", DynamicTestPipelineTask, self.b_config)
1376 return graph
1378 def test_prerequisite_inconsistency(self) -> None:
1379 """Test that we raise an exception when one edge defines a dataset type
1380 as a prerequisite and another does not.
1382 This test will hopefully someday go away (along with
1383 `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
1384 algorithm becomes more flexible.
1385 """
1386 self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
1387 self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
1388 graph = self.make_graph()
1389 with self.assertRaises(ConnectionTypeConsistencyError):
1390 graph.resolve(MockRegistry(self.dimensions, {}))
1392 def test_prerequisite_inconsistency_reversed(self) -> None:
1393 """Same as `test_prerequisite_inconsistency`, with the order the edges
1394 are added to the graph reversed.
1395 """
1396 self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
1397 self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
1398 graph = self.make_graph()
1399 with self.assertRaises(ConnectionTypeConsistencyError):
1400 graph.resolve(MockRegistry(self.dimensions, {}))
1402 def test_prerequisite_output(self) -> None:
1403 """Test that we raise an exception when one edge defines a dataset type
1404 as a prerequisite but another defines it as an output.
1405 """
1406 self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
1407 self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
1408 graph = self.make_graph()
1409 with self.assertRaises(ConnectionTypeConsistencyError):
1410 graph.resolve(MockRegistry(self.dimensions, {}))
1412 def test_skypix_missing(self) -> None:
1413 """Test that we raise an exception when one edge uses the "skypix"
1414 dimension as a placeholder but the dataset type is not registered.
1415 """
1416 self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
1417 dataset_type_name="d", dimensions={"skypix"}
1418 )
1419 graph = self.make_graph()
1420 with self.assertRaises(MissingDatasetTypeError):
1421 graph.resolve(MockRegistry(self.dimensions, {}))
1423 def test_skypix_inconsistent(self) -> None:
1424 """Test that we raise an exception when one edge uses the "skypix"
1425 dimension as a placeholder but the rest of the dimensions are
1426 inconsistent with the registered dataset type.
1427 """
1428 self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
1429 dataset_type_name="d", dimensions={"skypix", "visit"}
1430 )
1431 graph = self.make_graph()
1432 with self.assertRaises(IncompatibleDatasetTypeError):
1433 graph.resolve(
1434 MockRegistry(
1435 self.dimensions,
1436 {
1437 "d": DatasetType(
1438 "d",
1439 dimensions=self.dimensions.conform(["htm7"]),
1440 storageClass="StructuredDataDict",
1441 )
1442 },
1443 )
1444 )
1445 with self.assertRaises(IncompatibleDatasetTypeError):
1446 graph.resolve(
1447 MockRegistry(
1448 self.dimensions,
1449 {
1450 "d": DatasetType(
1451 "d",
1452 dimensions=self.dimensions.conform(["htm7", "visit", "skymap"]),
1453 storageClass="StructuredDataDict",
1454 )
1455 },
1456 )
1457 )
1459 def test_duplicate_outputs(self) -> None:
1460 """Test that we raise an exception when a dataset type node would have
1461 two write edges.
1462 """
1463 self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
1464 self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
1465 graph = self.make_graph()
1466 with self.assertRaises(DuplicateOutputError):
1467 graph.resolve(MockRegistry(self.dimensions, {}))
1469 def test_component_of_unregistered_parent(self) -> None:
1470 """Test that we raise an exception when a component dataset type's
1471 parent is not registered.
1472 """
1473 self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
1474 graph = self.make_graph()
1475 with self.assertRaises(MissingDatasetTypeError):
1476 graph.resolve(MockRegistry(self.dimensions, {}))
1478 def test_undefined_component(self) -> None:
1479 """Test that we raise an exception when a component dataset type's
1480 parent is registered, but its storage class does not have that
1481 component.
1482 """
1483 self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
1484 graph = self.make_graph()
1485 with self.assertRaises(IncompatibleDatasetTypeError):
1486 graph.resolve(
1487 MockRegistry(
1488 self.dimensions,
1489 {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
1490 )
1491 )
1493 @unittest.skipUnless(
1494 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1495 )
1496 def test_bad_component_storage_class(self) -> None:
1497 """Test that we raise an exception when a component dataset type's
1498 parent is registered, but does not have that component.
1499 """
1500 self.a_config.inputs["i"] = DynamicConnectionConfig(
1501 dataset_type_name="d.schema", storage_class="StructuredDataDict"
1502 )
1503 graph = self.make_graph()
1504 with self.assertRaises(IncompatibleDatasetTypeError):
1505 graph.resolve(
1506 MockRegistry(
1507 self.dimensions,
1508 {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
1509 )
1510 )
1512 def test_input_storage_class_incompatible_with_registry(self) -> None:
1513 """Test that we raise an exception when an input connection's storage
1514 class is incompatible with the registry definition.
1515 """
1516 self.a_config.inputs["i"] = DynamicConnectionConfig(
1517 dataset_type_name="d", storage_class="StructuredDataList"
1518 )
1519 graph = self.make_graph()
1520 with self.assertRaises(IncompatibleDatasetTypeError):
1521 graph.resolve(
1522 MockRegistry(
1523 self.dimensions,
1524 {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
1525 )
1526 )
1528 def test_output_storage_class_incompatible_with_registry(self) -> None:
1529 """Test that we raise an exception when an output connection's storage
1530 class is incompatible with the registry definition.
1531 """
1532 self.a_config.outputs["o"] = DynamicConnectionConfig(
1533 dataset_type_name="d", storage_class="StructuredDataList"
1534 )
1535 graph = self.make_graph()
1536 with self.assertRaises(IncompatibleDatasetTypeError):
1537 graph.resolve(
1538 MockRegistry(
1539 self.dimensions,
1540 {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
1541 )
1542 )
1544 def test_input_storage_class_incompatible_with_output(self) -> None:
1545 """Test that we raise an exception when an input connection's storage
1546 class is incompatible with the storage class of the output connection.
1547 """
1548 self.a_config.outputs["o"] = DynamicConnectionConfig(
1549 dataset_type_name="d", storage_class="StructuredDataDict"
1550 )
1551 self.b_config.inputs["i"] = DynamicConnectionConfig(
1552 dataset_type_name="d", storage_class="StructuredDataList"
1553 )
1554 graph = self.make_graph()
1555 with self.assertRaises(IncompatibleDatasetTypeError):
1556 graph.resolve(MockRegistry(self.dimensions, {}))
1558 def test_ambiguous_storage_class(self) -> None:
1559 """Test that we raise an exception when two input connections define
1560 the same dataset with different storage classes (even compatible ones)
1561 and there is no output connection or registry definition to take
1562 precedence.
1563 """
1564 self.a_config.inputs["i"] = DynamicConnectionConfig(
1565 dataset_type_name="d", storage_class="StructuredDataDict"
1566 )
1567 self.b_config.inputs["i"] = DynamicConnectionConfig(
1568 dataset_type_name="d", storage_class="StructuredDataList"
1569 )
1570 graph = self.make_graph()
1571 with self.assertRaises(MissingDatasetTypeError):
1572 graph.resolve(MockRegistry(self.dimensions, {}))
1574 @unittest.skipUnless(
1575 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1576 )
1577 def test_inputs_compatible_with_registry(self) -> None:
1578 """Test successful resolution of a dataset type where input edges have
1579 different but compatible storage classes and the dataset type is
1580 already registered.
1581 """
1582 self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
1583 self.b_config.inputs["i"] = DynamicConnectionConfig(
1584 dataset_type_name="d", storage_class="ArrowAstropy"
1585 )
1586 graph = self.make_graph()
1587 dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
1588 graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
1589 self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
1590 a_i = graph.tasks["a"].inputs["i"]
1591 b_i = graph.tasks["b"].inputs["i"]
1592 self.assertEqual(
1593 a_i.adapt_dataset_type(dataset_type),
1594 dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
1595 )
1596 self.assertEqual(
1597 b_i.adapt_dataset_type(dataset_type),
1598 dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
1599 )
1600 data_id = DataCoordinate.make_empty(self.dimensions)
1601 ref = DatasetRef(dataset_type, data_id, run="r")
1602 a_ref = a_i.adapt_dataset_ref(ref)
1603 b_ref = b_i.adapt_dataset_ref(ref)
1604 self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
1605 self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
1606 self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
1607 self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
1609 @unittest.skipUnless(
1610 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1611 )
1612 def test_output_compatible_with_registry(self) -> None:
1613 """Test successful resolution of a dataset type where an output edge
1614 has a different but compatible storage class from the dataset type
1615 already registered.
1616 """
1617 self.a_config.outputs["o"] = DynamicConnectionConfig(
1618 dataset_type_name="d", storage_class="ArrowTable"
1619 )
1620 graph = self.make_graph()
1621 dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
1622 graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
1623 self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
1624 a_o = graph.tasks["a"].outputs["o"]
1625 self.assertEqual(
1626 a_o.adapt_dataset_type(dataset_type),
1627 dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
1628 )
1629 data_id = DataCoordinate.make_empty(self.dimensions)
1630 ref = DatasetRef(dataset_type, data_id, run="r")
1631 a_ref = a_o.adapt_dataset_ref(ref)
1632 self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
1633 self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
1635 @unittest.skipUnless(
1636 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1637 )
1638 def test_inputs_compatible_with_output(self) -> None:
1639 """Test successful resolution of a dataset type where an input edge has
1640 a different but compatible storage class from the output edge, and
1641 the dataset type is not registered.
1642 """
1643 self.a_config.outputs["o"] = DynamicConnectionConfig(
1644 dataset_type_name="d", storage_class="ArrowTable"
1645 )
1646 self.b_config.inputs["i"] = DynamicConnectionConfig(
1647 dataset_type_name="d", storage_class="ArrowAstropy"
1648 )
1649 graph = self.make_graph()
1650 a_o = graph.tasks["a"].outputs["o"]
1651 b_i = graph.tasks["b"].inputs["i"]
1652 graph.resolve(MockRegistry(self.dimensions, {}))
1653 self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
1654 self.assertEqual(
1655 a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
1656 graph.dataset_types["d"].dataset_type,
1657 )
1658 self.assertEqual(
1659 b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
1660 graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
1661 )
1662 data_id = DataCoordinate.make_empty(self.dimensions)
1663 ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
1664 a_ref = a_o.adapt_dataset_ref(ref)
1665 b_ref = b_i.adapt_dataset_ref(ref)
1666 self.assertEqual(a_ref, ref)
1667 self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
1668 self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
1669 self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
1671 @unittest.skipUnless(
1672 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1673 )
1674 def test_component_resolved_by_input(self) -> None:
1675 """Test successful resolution of a component dataset type due to
1676 another input referencing the parent dataset type.
1677 """
1678 self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
1679 self.b_config.inputs["i"] = DynamicConnectionConfig(
1680 dataset_type_name="d.schema", storage_class="ArrowSchema"
1681 )
1682 graph = self.make_graph()
1683 parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
1684 graph.resolve(MockRegistry(self.dimensions, {}))
1685 self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
1686 a_i = graph.tasks["a"].inputs["i"]
1687 b_i = graph.tasks["b"].inputs["i"]
1688 self.assertEqual(b_i.dataset_type_name, "d.schema")
1689 self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
1690 self.assertEqual(
1691 b_i.adapt_dataset_type(parent_dataset_type),
1692 parent_dataset_type.makeComponentDatasetType("schema"),
1693 )
1694 data_id = DataCoordinate.make_empty(self.dimensions)
1695 ref = DatasetRef(parent_dataset_type, data_id, run="r")
1696 a_ref = a_i.adapt_dataset_ref(ref)
1697 b_ref = b_i.adapt_dataset_ref(ref)
1698 self.assertEqual(a_ref, ref)
1699 self.assertEqual(b_ref, ref.makeComponentRef("schema"))
1700 self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
1701 self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
1703 @unittest.skipUnless(
1704 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1705 )
1706 def test_component_resolved_by_output(self) -> None:
1707 """Test successful resolution of a component dataset type due to
1708 an output connection referencing the parent dataset type.
1709 """
1710 self.a_config.outputs["o"] = DynamicConnectionConfig(
1711 dataset_type_name="d", storage_class="ArrowTable"
1712 )
1713 self.b_config.inputs["i"] = DynamicConnectionConfig(
1714 dataset_type_name="d.schema", storage_class="ArrowSchema"
1715 )
1716 graph = self.make_graph()
1717 parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
1718 graph.resolve(MockRegistry(self.dimensions, {}))
1719 self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
1720 a_o = graph.tasks["a"].outputs["o"]
1721 b_i = graph.tasks["b"].inputs["i"]
1722 self.assertEqual(b_i.dataset_type_name, "d.schema")
1723 self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
1724 self.assertEqual(
1725 b_i.adapt_dataset_type(parent_dataset_type),
1726 parent_dataset_type.makeComponentDatasetType("schema"),
1727 )
1728 data_id = DataCoordinate.make_empty(self.dimensions)
1729 ref = DatasetRef(parent_dataset_type, data_id, run="r")
1730 a_ref = a_o.adapt_dataset_ref(ref)
1731 b_ref = b_i.adapt_dataset_ref(ref)
1732 self.assertEqual(a_ref, ref)
1733 self.assertEqual(b_ref, ref.makeComponentRef("schema"))
1734 self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
1735 self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
1737 @unittest.skipUnless(
1738 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1739 )
1740 def test_component_storage_class_converted(self) -> None:
1741 """Test successful resolution of a component dataset type due to
1742 an output connection referencing the parent dataset type, but with a
1743 different (convertible) storage class.
1744 """
1745 self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="DataFrame")
1746 self.b_config.inputs["i"] = DynamicConnectionConfig(
1747 dataset_type_name="d.schema", storage_class="ArrowSchema"
1748 )
1749 graph = self.make_graph()
1750 output_parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
1751 graph.resolve(MockRegistry(self.dimensions, {}))
1752 self.assertEqual(graph.dataset_types["d"].dataset_type, output_parent_dataset_type)
1753 a_o = graph.tasks["a"].outputs["o"]
1754 b_i = graph.tasks["b"].inputs["i"]
1755 self.assertEqual(b_i.dataset_type_name, "d.schema")
1756 self.assertEqual(a_o.adapt_dataset_type(output_parent_dataset_type), output_parent_dataset_type)
1757 self.assertEqual(
1758 # We don't really want to compare the full dataset type here,
1759 # because that's going to include a parentStorageClass that may or
1760 # may not make sense.
1761 b_i.adapt_dataset_type(output_parent_dataset_type).storageClass_name,
1762 get_mock_name("ArrowSchema"),
1763 )
1764 data_id = DataCoordinate.make_empty(self.dimensions)
1765 ref = DatasetRef(output_parent_dataset_type, data_id, run="r")
1766 a_ref = a_o.adapt_dataset_ref(ref)
1767 b_ref = b_i.adapt_dataset_ref(ref)
1768 self.assertEqual(a_ref, ref)
1769 self.assertEqual(b_ref.datasetType.storageClass_name, get_mock_name("ArrowSchema"))
1770 self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
1771 self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
1773 @unittest.skipUnless(
1774 _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
1775 )
1776 def test_component_resolved_by_registry(self) -> None:
1777 """Test successful resolution of a component dataset type due to
1778 the parent dataset type already being registered.
1779 """
1780 self.b_config.inputs["i"] = DynamicConnectionConfig(
1781 dataset_type_name="d.schema", storage_class="ArrowSchema"
1782 )
1783 graph = self.make_graph()
1784 parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
1785 graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
1786 self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
1787 b_i = graph.tasks["b"].inputs["i"]
1788 self.assertEqual(b_i.dataset_type_name, "d.schema")
1789 self.assertEqual(
1790 b_i.adapt_dataset_type(parent_dataset_type),
1791 parent_dataset_type.makeComponentDatasetType("schema"),
1792 )
1793 data_id = DataCoordinate.make_empty(self.dimensions)
1794 ref = DatasetRef(parent_dataset_type, data_id, run="r")
1795 b_ref = b_i.adapt_dataset_ref(ref)
1796 self.assertEqual(b_ref, ref.makeComponentRef("schema"))
1797 self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
1799 def test_optional_input(self) -> None:
1800 """Test that regular Input connections with minimum=0 result in
1801 dataset type nodes that are no initial query constraints.
1802 """
1803 self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", minimum=0)
1804 graph = self.make_graph()
1805 graph.resolve(MockRegistry(self.dimensions, {}))
1806 self.assertFalse(graph.dataset_types["d"].is_initial_query_constraint)
1808 def test_invalid_dimensions(self) -> None:
1809 """Test that a connection with an invalid dimensions raises an
1810 exception (from butler) with the connection name information included.
1811 """
1812 self.a_config.outputs["o"] = DynamicConnectionConfig(
1813 dataset_type_name="d", dimensions=["frog"], storage_class="StructuredDataList"
1814 )
1815 graph = self.make_graph()
1816 with self.assertRaises(Exception) as error:
1817 graph.resolve(MockRegistry(self.dimensions, {}))
1818 self.assertEqual(error.exception.__notes__, ["In connection 'o' of task 'a'."])
1820 def test_invalid_dataset_type_name(self) -> None:
1821 """Test that a connection with an invalid dataset type name raises an
1822 exception (from butler) with the connection name information included.
1823 """
1824 self.a_config.outputs["o"] = DynamicConnectionConfig(
1825 dataset_type_name=":?", storage_class="StructuredDataList"
1826 )
1827 graph = self.make_graph()
1828 with self.assertRaises(Exception) as error:
1829 graph.resolve(MockRegistry(self.dimensions, {}))
1830 self.assertEqual(error.exception.__notes__, ["In connection 'o' of task 'a'."])
if __name__ == "__main__":
    # Initialize LSST test utilities before handing off to the standard
    # unittest runner.
    lsst.utils.tests.init()
    unittest.main()