Coverage for tests/test_pipeline_graph.py: 16%
523 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-28 11:05 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests of things related to the GraphBuilder class."""
30import copy
31import io
32import logging
33import textwrap
34import unittest
35from typing import Any
37import lsst.pipe.base.automatic_connection_constants as acc
38import lsst.utils.tests
39from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory
40from lsst.daf.butler.registry import MissingDatasetTypeError
41from lsst.pipe.base.pipeline_graph import (
42 ConnectionTypeConsistencyError,
43 DuplicateOutputError,
44 Edge,
45 EdgesChangedError,
46 IncompatibleDatasetTypeError,
47 NodeKey,
48 NodeType,
49 PipelineGraph,
50 PipelineGraphError,
51 TaskImportMode,
52 UnresolvedGraphError,
53 visualization,
54)
55from lsst.pipe.base.tests.mocks import (
56 DynamicConnectionConfig,
57 DynamicTestPipelineTask,
58 DynamicTestPipelineTaskConfig,
59 get_mock_name,
60)
# Standard per-module logger (not used in the portion of the file shown here).
_LOG = logging.getLogger(__name__)
class MockRegistry:
    """A test-utility stand-in for lsst.daf.butler.Registry that just knows
    how to get dataset types.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        # Translate a missing key into the butler exception type that a real
        # Registry raises for unknown dataset types.
        if name not in self._dataset_types:
            raise MissingDatasetTypeError(name)
        return self._dataset_types[name]
81class PipelineGraphTestCase(unittest.TestCase):
82 """Tests for the `PipelineGraph` class.
84 Tests for `PipelineGraph.resolve` are mostly in
85 `PipelineGraphResolveTestCase` later in this file.
86 """
    def setUp(self) -> None:
        # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
        # 'input', 'intermediate', and 'output'. There are no dimensions on
        # any of those. We add tasks in reverse order to better test sorting.
        # There is one labeled task subset, 'only_b', with just 'b' in it.
        # We copy the configs so the originals (the instance attributes) can
        # be modified and reused after the ones passed in to the graph are
        # frozen.
        self.description = "A pipeline for PipelineGraph unit tests."
        self.graph = PipelineGraph()
        self.graph.description = self.description
        # Task 'b': consumes the init-output 'schema' and the run-time
        # intermediate produced by task 'a'.
        self.b_config = DynamicTestPipelineTaskConfig()
        self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
        self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
        # Task 'a': consumes the overall input and produces both the
        # intermediate and the init-output 'schema'.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
        self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
        self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
        self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
        self.graph.add_task_subset("only_b", ["b"])
        self.subset_description = "A subset with only task B in it."
        self.graph.task_subsets["only_b"].description = self.subset_description
        self.dimensions = DimensionUniverse()
        # Show full diffs for the large dict/set comparisons in this file.
        self.maxDiff = None
115 def test_unresolved_accessors(self) -> None:
116 """Test attribute accessors, iteration, and simple methods on a graph
117 that has not had `PipelineGraph.resolve` called on it.
118 """
119 self.check_base_accessors(self.graph)
120 self.assertEqual(
121 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)"
122 )
124 def test_sorting(self) -> None:
125 """Test sort methods on PipelineGraph."""
126 self.assertFalse(self.graph.has_been_sorted)
127 self.assertFalse(self.graph.is_sorted)
128 self.graph.sort()
129 self.check_sorted(self.graph)
131 def test_unresolved_xgraph_export(self) -> None:
132 """Test exporting an unresolved PipelineGraph to networkx in various
133 ways.
134 """
135 self.check_make_xgraph(self.graph, resolved=False)
136 self.check_make_bipartite_xgraph(self.graph, resolved=False)
137 self.check_make_task_xgraph(self.graph, resolved=False)
138 self.check_make_dataset_type_xgraph(self.graph, resolved=False)
140 def test_unresolved_stream_io(self) -> None:
141 """Test round-tripping an unresolved PipelineGraph through in-memory
142 serialization.
143 """
144 stream = io.BytesIO()
145 self.graph._write_stream(stream)
146 stream.seek(0)
147 roundtripped = PipelineGraph._read_stream(stream)
148 self.check_make_xgraph(roundtripped, resolved=False)
150 def test_unresolved_file_io(self) -> None:
151 """Test round-tripping an unresolved PipelineGraph through file
152 serialization.
153 """
154 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
155 self.graph._write_uri(filename)
156 roundtripped = PipelineGraph._read_uri(filename)
157 self.check_make_xgraph(roundtripped, resolved=False)
159 def test_unresolved_deferred_import_io(self) -> None:
160 """Test round-tripping an unresolved PipelineGraph through
161 serialization, without immediately importing tasks on read.
162 """
163 stream = io.BytesIO()
164 self.graph._write_stream(stream)
165 stream.seek(0)
166 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
167 self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
168 # Check that we can still resolve the graph without importing tasks.
169 roundtripped.resolve(MockRegistry(self.dimensions, {}))
170 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
171 roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
172 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
174 def test_resolved_accessors(self) -> None:
175 """Test attribute accessors, iteration, and simple methods on a graph
176 that has had `PipelineGraph.resolve` called on it.
178 This includes the accessors available on unresolved graphs as well as
179 new ones, and we expect the resolved graph to be sorted as well.
180 """
181 self.graph.resolve(MockRegistry(self.dimensions, {}))
182 self.check_base_accessors(self.graph)
183 self.check_sorted(self.graph)
184 self.assertEqual(
185 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})"
186 )
187 self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty)
188 self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})")
189 self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1"))
190 self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
191 self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
192 self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")
194 def test_resolved_xgraph_export(self) -> None:
195 """Test exporting a resolved PipelineGraph to networkx in various
196 ways.
197 """
198 self.graph.resolve(MockRegistry(self.dimensions, {}))
199 self.check_make_xgraph(self.graph, resolved=True)
200 self.check_make_bipartite_xgraph(self.graph, resolved=True)
201 self.check_make_task_xgraph(self.graph, resolved=True)
202 self.check_make_dataset_type_xgraph(self.graph, resolved=True)
204 def test_resolved_stream_io(self) -> None:
205 """Test round-tripping a resolved PipelineGraph through in-memory
206 serialization.
207 """
208 self.graph.resolve(MockRegistry(self.dimensions, {}))
209 stream = io.BytesIO()
210 self.graph._write_stream(stream)
211 stream.seek(0)
212 roundtripped = PipelineGraph._read_stream(stream)
213 self.check_make_xgraph(roundtripped, resolved=True)
215 def test_resolved_file_io(self) -> None:
216 """Test round-tripping a resolved PipelineGraph through file
217 serialization.
218 """
219 self.graph.resolve(MockRegistry(self.dimensions, {}))
220 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
221 self.graph._write_uri(filename)
222 roundtripped = PipelineGraph._read_uri(filename)
223 self.check_make_xgraph(roundtripped, resolved=True)
225 def test_resolved_deferred_import_io(self) -> None:
226 """Test round-tripping a resolved PipelineGraph through serialization,
227 without immediately importing tasks on read.
228 """
229 self.graph.resolve(MockRegistry(self.dimensions, {}))
230 stream = io.BytesIO()
231 self.graph._write_stream(stream)
232 stream.seek(0)
233 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
234 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
235 roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
236 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
238 def test_unresolved_copies(self) -> None:
239 """Test making copies of an unresolved PipelineGraph."""
240 copy1 = self.graph.copy()
241 self.assertIsNot(copy1, self.graph)
242 self.check_make_xgraph(copy1, resolved=False)
243 copy2 = copy.copy(self.graph)
244 self.assertIsNot(copy2, self.graph)
245 self.check_make_xgraph(copy2, resolved=False)
246 copy3 = copy.deepcopy(self.graph)
247 self.assertIsNot(copy3, self.graph)
248 self.check_make_xgraph(copy3, resolved=False)
250 def test_resolved_copies(self) -> None:
251 """Test making copies of a resolved PipelineGraph."""
252 self.graph.resolve(MockRegistry(self.dimensions, {}))
253 copy1 = self.graph.copy()
254 self.assertIsNot(copy1, self.graph)
255 self.check_make_xgraph(copy1, resolved=True)
256 copy2 = copy.copy(self.graph)
257 self.assertIsNot(copy2, self.graph)
258 self.check_make_xgraph(copy2, resolved=True)
259 copy3 = copy.deepcopy(self.graph)
260 self.assertIsNot(copy3, self.graph)
261 self.check_make_xgraph(copy3, resolved=True)
263 def check_base_accessors(self, graph: PipelineGraph) -> None:
264 """Run parameterized tests that check attribute access, iteration, and
265 simple methods.
267 The given graph must be unchanged from the one defined in `setUp`,
268 other than sorting.
269 """
270 self.assertEqual(graph.description, self.description)
271 self.assertEqual(graph.tasks.keys(), {"a", "b"})
272 self.assertEqual(
273 graph.dataset_types.keys(),
274 {
275 "schema",
276 "input_1",
277 "intermediate_1",
278 "output_1",
279 "a_config",
280 "a_log",
281 "a_metadata",
282 "b_config",
283 "b_log",
284 "b_metadata",
285 },
286 )
287 self.assertEqual(graph.task_subsets.keys(), {"only_b"})
288 self.assertEqual(
289 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)},
290 {
291 (
292 NodeKey(NodeType.DATASET_TYPE, "input_1"),
293 NodeKey(NodeType.TASK, "a"),
294 "input_1 -> a (input1)",
295 ),
296 (
297 NodeKey(NodeType.TASK, "a"),
298 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
299 "a -> intermediate_1 (output1)",
300 ),
301 (
302 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
303 NodeKey(NodeType.TASK, "b"),
304 "intermediate_1 -> b (input1)",
305 ),
306 (
307 NodeKey(NodeType.TASK, "b"),
308 NodeKey(NodeType.DATASET_TYPE, "output_1"),
309 "b -> output_1 (output1)",
310 ),
311 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"),
312 (
313 NodeKey(NodeType.TASK, "a"),
314 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
315 "a -> a_metadata (_metadata)",
316 ),
317 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"),
318 (
319 NodeKey(NodeType.TASK, "b"),
320 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
321 "b -> b_metadata (_metadata)",
322 ),
323 },
324 )
325 self.assertEqual(
326 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)},
327 {
328 (
329 NodeKey(NodeType.TASK_INIT, "a"),
330 NodeKey(NodeType.DATASET_TYPE, "schema"),
331 "a -> schema (out_schema)",
332 ),
333 (
334 NodeKey(NodeType.DATASET_TYPE, "schema"),
335 NodeKey(NodeType.TASK_INIT, "b"),
336 "schema -> b (in_schema)",
337 ),
338 (
339 NodeKey(NodeType.TASK_INIT, "a"),
340 NodeKey(NodeType.DATASET_TYPE, "a_config"),
341 "a -> a_config (_config)",
342 ),
343 (
344 NodeKey(NodeType.TASK_INIT, "b"),
345 NodeKey(NodeType.DATASET_TYPE, "b_config"),
346 "b -> b_config (_config)",
347 ),
348 },
349 )
350 self.assertEqual(
351 {(node_type, name) for node_type, name, _ in graph.iter_nodes()},
352 {
353 NodeKey(NodeType.TASK, "a"),
354 NodeKey(NodeType.TASK, "b"),
355 NodeKey(NodeType.TASK_INIT, "a"),
356 NodeKey(NodeType.TASK_INIT, "b"),
357 NodeKey(NodeType.DATASET_TYPE, "schema"),
358 NodeKey(NodeType.DATASET_TYPE, "input_1"),
359 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
360 NodeKey(NodeType.DATASET_TYPE, "output_1"),
361 NodeKey(NodeType.DATASET_TYPE, "a_config"),
362 NodeKey(NodeType.DATASET_TYPE, "a_log"),
363 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
364 NodeKey(NodeType.DATASET_TYPE, "b_config"),
365 NodeKey(NodeType.DATASET_TYPE, "b_log"),
366 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
367 },
368 )
369 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"})
370 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"})
371 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"})
372 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set())
373 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"})
374 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"})
375 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set())
377 self.assertIsNone(graph.producing_edge_of("input_1"))
378 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a")
379 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b")
380 self.assertIsNone(graph.producer_of("input_1"))
381 self.assertEqual(graph.producer_of("intermediate_1").label, "a")
382 self.assertEqual(graph.producer_of("output_1").label, "b")
384 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"})
385 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"})
386 self.assertEqual(graph.inputs_of("a", init=True).keys(), set())
387 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"})
388 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"})
389 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"})
390 self.assertEqual(
391 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"}
392 )
393 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"})
394 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"})
395 self.assertEqual(
396 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"}
397 )
398 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"})
399 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set())
401 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
402 self.assertEqual(
403 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}"
404 )
406 def check_sorted(self, graph: PipelineGraph) -> None:
407 """Run a battery of tests on a PipelineGraph that must be
408 deterministically sorted.
410 The given graph must be unchanged from the one defined in `setUp`,
411 other than sorting.
412 """
413 self.assertTrue(graph.has_been_sorted)
414 self.assertTrue(graph.is_sorted)
415 self.assertEqual(
416 [(node_type, name) for node_type, name, _ in graph.iter_nodes()],
417 [
418 # We only advertise that the order is topological and
419 # deterministic, so this test is slightly over-specified; there
420 # are other orders that are consistent with our guarantees.
421 NodeKey(NodeType.DATASET_TYPE, "input_1"),
422 NodeKey(NodeType.TASK_INIT, "a"),
423 NodeKey(NodeType.DATASET_TYPE, "a_config"),
424 NodeKey(NodeType.DATASET_TYPE, "schema"),
425 NodeKey(NodeType.TASK_INIT, "b"),
426 NodeKey(NodeType.DATASET_TYPE, "b_config"),
427 NodeKey(NodeType.TASK, "a"),
428 NodeKey(NodeType.DATASET_TYPE, "a_log"),
429 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
430 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
431 NodeKey(NodeType.TASK, "b"),
432 NodeKey(NodeType.DATASET_TYPE, "b_log"),
433 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
434 NodeKey(NodeType.DATASET_TYPE, "output_1"),
435 ],
436 )
437 # Most users should only care that the tasks and dataset types are
438 # topologically sorted.
439 self.assertEqual(list(graph.tasks), ["a", "b"])
440 self.assertEqual(
441 list(graph.dataset_types),
442 [
443 "input_1",
444 "a_config",
445 "schema",
446 "b_config",
447 "a_log",
448 "a_metadata",
449 "intermediate_1",
450 "b_log",
451 "b_metadata",
452 "output_1",
453 ],
454 )
455 # __str__ and __repr__ of course work on unsorted mapping views, too,
456 # but the order of elements is then nondeterministic and hard to test.
457 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})")
458 self.assertEqual(
459 repr(self.graph.dataset_types),
460 (
461 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, "
462 "intermediate_1, b_log, b_metadata, output_1})"
463 ),
464 )
    def check_make_xgraph(
        self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
    ) -> None:
        """Check that the given graph exports as expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``) or round-tripped
        through serialization without tasks being imported (if
        ``imported_and_configured=False``).
        """
        xgraph = graph.make_xgraph()
        # The full export includes run edges, init edges, and one synthetic
        # edge connecting each task-init node to its run-time task node.
        expected_edges = (
            {edge.key for edge in graph.iter_edges()}
            | {edge.key for edge in graph.iter_edges(init=True)}
            | {
                (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
                (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
            }
        )
        test_edges = set(xgraph.edges)
        self.assertEqual(test_edges, expected_edges)
        expected_nodes = {
            NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                "schema", resolved, is_initial_query_constraint=False
            ),
            # Only the overall input participates in the initial data ID query
            # constraint.
            NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                "input_1", resolved, is_initial_query_constraint=True
            ),
            NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                "intermediate_1", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                "output_1", resolved, is_initial_query_constraint=False
            ),
        }
        test_nodes = dict(xgraph.nodes.items())
        self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
        # Compare node-by-node (rather than the whole dicts at once) so a
        # failure message identifies the offending node key.
        for key, expected_node in expected_nodes.items():
            test_node = test_nodes[key]
            self.assertEqual(expected_node, test_node, key)
    def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's init-only or runtime subset exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Run-time bipartite view: task nodes and run dataset-type nodes only.
        run_xgraph = graph.make_bipartite_xgraph()
        self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init bipartite view: task-init nodes, schema, and config outputs.
        init_xgraph = graph.make_bipartite_xgraph(
            init=True,
        )
        self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
    def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's task-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Run-time projection: a single edge a -> b via the intermediate.
        run_xgraph = graph.make_task_xgraph()
        self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
            },
        )
        # Init projection: a single edge a -> b via the schema init-output.
        init_xgraph = graph.make_task_xgraph(
            init=True,
        )
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
            },
        )
    def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's dataset-type-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Run-time projection: inputs of each task connect to all of its
        # outputs, including the automatic log and metadata dataset types.
        run_xgraph = graph.make_dataset_type_xgraph()
        self.assertEqual(
            set(run_xgraph.edges),
            {
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "output_1"),
                ),
                (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
                ),
            },
        )
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init projection: schema feeds b's config output (a's config has no
        # init-inputs upstream of it, so it has no incoming edge).
        init_xgraph = graph.make_dataset_type_xgraph(init=True)
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
661 def get_expected_task_node(
662 self, label: str, resolved: bool, imported_and_configured: bool = True
663 ) -> dict[str, Any]:
664 """Construct a networkx-export task node for comparison."""
665 result = self.get_expected_task_init_node(
666 label, resolved, imported_and_configured=imported_and_configured
667 )
668 if resolved:
669 result["dimensions"] = self.dimensions.empty
670 result["raw_dimensions"] = frozenset()
671 return result
673 def get_expected_task_init_node(
674 self, label: str, resolved: bool, imported_and_configured: bool = True
675 ) -> dict[str, Any]:
676 """Construct a networkx-export task init for comparison."""
677 result = {
678 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask",
679 "bipartite": 1,
680 }
681 if imported_and_configured:
682 result["task_class"] = DynamicTestPipelineTask
683 result["config"] = getattr(self, f"{label}_config")
684 return result
686 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]:
687 """Construct a networkx-export init-output config dataset type node for
688 comparison.
689 """
690 if not resolved:
691 return {"bipartite": 0}
692 else:
693 return {
694 "dataset_type": DatasetType(
695 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
696 self.dimensions.empty,
697 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
698 ),
699 "is_initial_query_constraint": False,
700 "is_prerequisite": False,
701 "dimensions": self.dimensions.empty,
702 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
703 "bipartite": 0,
704 }
706 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]:
707 """Construct a networkx-export output log dataset type node for
708 comparison.
709 """
710 if not resolved:
711 return {"bipartite": 0}
712 else:
713 return {
714 "dataset_type": DatasetType(
715 acc.LOG_OUTPUT_TEMPLATE.format(label=label),
716 self.dimensions.empty,
717 acc.LOG_OUTPUT_STORAGE_CLASS,
718 ),
719 "is_initial_query_constraint": False,
720 "is_prerequisite": False,
721 "dimensions": self.dimensions.empty,
722 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS,
723 "bipartite": 0,
724 }
726 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]:
727 """Construct a networkx-export output metadata dataset type node for
728 comparison.
729 """
730 if not resolved:
731 return {"bipartite": 0}
732 else:
733 return {
734 "dataset_type": DatasetType(
735 acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
736 self.dimensions.empty,
737 acc.METADATA_OUTPUT_STORAGE_CLASS,
738 ),
739 "is_initial_query_constraint": False,
740 "is_prerequisite": False,
741 "dimensions": self.dimensions.empty,
742 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS,
743 "bipartite": 0,
744 }
746 def get_expected_connection_node(
747 self, name: str, resolved: bool, *, is_initial_query_constraint: bool
748 ) -> dict[str, Any]:
749 """Construct a networkx-export dataset type node for comparison."""
750 if not resolved:
751 return {"bipartite": 0}
752 else:
753 return {
754 "dataset_type": DatasetType(
755 name,
756 self.dimensions.empty,
757 get_mock_name("StructuredDataDict"),
758 ),
759 "is_initial_query_constraint": is_initial_query_constraint,
760 "is_prerequisite": False,
761 "dimensions": self.dimensions.empty,
762 "storage_class_name": get_mock_name("StructuredDataDict"),
763 "bipartite": 0,
764 }
    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with DimensionUniverse, all tasks added to
        it should have resolved dimensions, but not (yet) resolved dataset
        types. We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        # Adding a task to a universe-bearing graph resolves its dimensions
        # immediately.
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions.
        self.graph.add_task_nodes([graph.tasks["b1"]])
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)
    def test_group_by_dimensions(self) -> None:
        """Test PipelineGraph.group_by_dimensions."""
        # Grouping requires resolved dataset types and task dimensions.
        with self.assertRaises(UnresolvedGraphError):
            self.graph.group_by_dimensions()
        # Give the two tasks and their connections distinct dimensions so the
        # grouping is nontrivial: task 'a' on visit, task 'b' on htm7.
        self.a_config.dimensions = ["visit"]
        self.a_config.outputs["output1"].dimensions = ["visit"]
        self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
            dataset_type_name="prereq_1",
            multiple=True,
            dimensions=["htm7"],
            is_calibration=True,
        )
        self.b_config.dimensions = ["htm7"]
        self.b_config.inputs["input1"].dimensions = ["visit"]
        self.b_config.inputs["input1"].multiple = True
        self.b_config.outputs["output1"].dimensions = ["htm7"]
        self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        visit_dims = self.dimensions.conform(["visit"])
        htm7_dims = self.dimensions.conform(["htm7"])
        # Expected mapping: dimension group -> (tasks, dataset types).
        expected = {
            self.dimensions.empty.as_group(): (
                {},
                {
                    "schema": self.graph.dataset_types["schema"],
                    "input_1": self.graph.dataset_types["input_1"],
                    "a_config": self.graph.dataset_types["a_config"],
                    "b_config": self.graph.dataset_types["b_config"],
                },
            ),
            visit_dims: (
                {"a": self.graph.tasks["a"]},
                {
                    "a_log": self.graph.dataset_types["a_log"],
                    "a_metadata": self.graph.dataset_types["a_metadata"],
                    "intermediate_1": self.graph.dataset_types["intermediate_1"],
                },
            ),
            htm7_dims: (
                {"b": self.graph.tasks["b"]},
                {
                    "b_log": self.graph.dataset_types["b_log"],
                    "b_metadata": self.graph.dataset_types["b_metadata"],
                    "output_1": self.graph.dataset_types["output_1"],
                },
            ),
        }
        self.assertEqual(self.graph.group_by_dimensions(), expected)
        # Prerequisite inputs appear only when explicitly requested.
        expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
        self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)
844 def test_add_and_remove(self) -> None:
845 """Tests for adding and removing tasks and task subsets from a
846 PipelineGraph.
847 """
848 # Can't remove a task while it's still in a subset.
849 with self.assertRaises(PipelineGraphError):
850 self.graph.remove_tasks(["b"], drop_from_subsets=False)
851 # ...unless you remove the subset.
852 self.graph.remove_task_subset("only_b")
853 self.assertFalse(self.graph.task_subsets)
854 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False)
855 self.assertFalse(referencing_subsets)
856 self.assertEqual(self.graph.tasks.keys(), {"a"})
857 # Add that task back in.
858 self.graph.add_task_nodes([b])
859 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
860 # Add the subset back in.
861 self.graph.add_task_subset("only_b", {"b"})
862 self.assertEqual(self.graph.task_subsets.keys(), {"only_b"})
863 # Resolve the graph's dataset types and task dimensions.
864 self.graph.resolve(MockRegistry(self.dimensions, {}))
865 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
866 self.assertTrue(self.graph.dataset_types.is_resolved("output_1"))
867 self.assertTrue(self.graph.dataset_types.is_resolved("schema"))
868 self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1"))
869 # Remove the task while removing it from the subset automatically. This
870 # should also unresolve (only) the referenced dataset types and drop
871 # any datasets no longer attached to any task.
872 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
873 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
874 self.assertEqual(referencing_subsets, {"only_b"})
875 self.assertEqual(self.graph.tasks.keys(), {"a"})
876 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
877 self.assertNotIn("output1", self.graph.dataset_types)
878 self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
879 self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Changing an output's storage class modifies the task's edges.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        # With checking enabled, the edge change is detected and rejected.
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        # Make a change that does affect edges; this will unresolve most
        # dataset types.
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )
908 def check_visualization(self, graph: PipelineGraph, expected: str, **kwargs: Any) -> None:
909 """Run pipeline graph visualization with the given kwargs and check
910 that the output is the given expected string.
912 Parameters
913 ----------
914 graph : `lsst.pipe.base.pipeline_graph.PipelineGraph`
915 Pipeline graph to visualize.
916 expected : `str`
917 Expected output of the visualization. Will be passed through
918 `textwrap.dedent`, to allow it to be written with triple-quotes.
919 **kwargs
920 Forwarded to `lsst.pipe.base.pipeline_graph.visualization.show`.
921 """
922 stream = io.StringIO()
923 visualization.show(graph, stream, **kwargs)
924 self.assertEqual(textwrap.dedent(expected), stream.getvalue())
    def test_unresolved_visualization(self) -> None:
        """Test pipeline graph text-based visualization on unresolved
        graphs.
        """
        # Tasks only (filled squares), with all merging options disabled.
        self.check_visualization(
            self.graph,
            """
            ■ a
            │
            ■ b
            """,
            merge_input_trees=0,
            merge_output_trees=0,
            merge_intermediates=False,
        )
        # Same graph, also showing dataset-type nodes (open circles).
        self.check_visualization(
            self.graph,
            """
            ○ input_1
            │
            ■ a
            │
            ○ intermediate_1
            │
            ■ b
            │
            ○ output_1
            """,
            dataset_types=True,
        )
    def test_resolved_visualization(self) -> None:
        """Test pipeline graph text-based visualization on resolved graphs."""
        self.graph.resolve(MockRegistry(dimensions=self.dimensions, dataset_types={}))
        # Resolved tasks show their (empty) dimensions and, with
        # task_classes="concise", the bare class name.
        self.check_visualization(
            self.graph,
            """
            ■ a: {} DynamicTestPipelineTask
            │
            ■ b: {} DynamicTestPipelineTask
            """,
            task_classes="concise",
            merge_input_trees=0,
            merge_output_trees=0,
            merge_intermediates=False,
        )
        # task_classes="full" uses fully-qualified class names; dataset-type
        # nodes show their storage class.
        self.check_visualization(
            self.graph,
            """
            ○ input_1: {} _mock_StructuredDataDict
            │
            ■ a: {} lsst.pipe.base.tests.mocks.DynamicTestPipelineTask
            │
            ○ intermediate_1: {} _mock_StructuredDataDict
            │
            ■ b: {} lsst.pipe.base.tests.mocks.DynamicTestPipelineTask
            │
            ○ output_1: {} _mock_StructuredDataDict
            """,
            task_classes="full",
            dataset_types=True,
        )
def _have_example_storage_classes() -> bool:
    """Check whether some storage classes work as expected.

    Given that these have registered converters, it shouldn't actually be
    necessary to be able to import those types in order to determine that
    they're convertible, but the storage class machinery is implemented such
    that types that can't be imported can't be converted, and while that's
    inconvenient here it's totally fine in non-testing scenarios where you
    only care about a storage class if you can actually use it.
    """
    getter = StorageClassFactory().getStorageClass
    # Each pair (a, b) must satisfy: a can be converted to from b.
    required_conversions = [
        ("ArrowTable", "ArrowAstropy"),
        ("ArrowAstropy", "ArrowTable"),
        ("ArrowTable", "DataFrame"),
        ("DataFrame", "ArrowTable"),
    ]
    return all(getter(a).can_convert(getter(b)) for a, b in required_conversions)
class PipelineGraphResolveTestCase(unittest.TestCase):
    """More extensive tests for PipelineGraph.resolve and its private helper
    methods.

    These are in a separate TestCase because they utilize a different `setUp`
    from the rest of the `PipelineGraph` tests.
    """

    def setUp(self) -> None:
        # Fresh task configs; each test customizes their connections before
        # calling make_graph().
        self.a_config = DynamicTestPipelineTaskConfig()
        self.b_config = DynamicTestPipelineTaskConfig()
        self.dimensions = DimensionUniverse()
        # Show full diffs on assertion failures.
        self.maxDiff = None

    def make_graph(self) -> PipelineGraph:
        """Build an unresolved two-task graph from the current configs."""
        graph = PipelineGraph()
        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
        return graph

    def test_prerequisite_inconsistency(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite and another does not.

        This test will hopefully someday go away (along with
        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
        algorithm becomes more flexible.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_inconsistency_reversed(self) -> None:
        """Same as `test_prerequisite_inconsistency`, with the order the edges
        are added to the graph reversed.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_output(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite but another defines it as an output.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_missing(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the dataset type is not registered.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix"}
        )
        graph = self.make_graph()
        # The registry has no "d", so the placeholder cannot be resolved.
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_inconsistent(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the rest of the dimensions are
        inconsistent with the registered dataset type.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix", "visit"}
        )
        graph = self.make_graph()
        # Registered dimensions lack "visit" entirely.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.conform(["htm7"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )
        # Registered dimensions have extras beyond skypix+visit.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.conform(["htm7", "visit", "skymap"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )

    def test_duplicate_outputs(self) -> None:
        """Test that we raise an exception when a dataset type node would have
        two write edges.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(DuplicateOutputError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_component_of_unregistered_parent(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is not registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_undefined_component(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but its storage class does not have that
        component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        # StructuredDataDict has no "c" component.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_bad_component_storage_class(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but the storage class requested by the
        connection is incompatible with that component's storage class.
        """
        # ArrowTable's "schema" component is not a StructuredDataDict.
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="StructuredDataDict"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
                )
            )

    def test_input_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_output_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an output connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_input_storage_class_incompatible_with_output(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the storage class of the output connection.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_ambiguous_storage_class(self) -> None:
        """Test that we raise an exception when two input connections define
        the same dataset with different storage classes (even compatible ones)
        and there is no output connection or registry definition to take
        precedence.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where input edges have
        different but compatible storage classes and the dataset type is
        already registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        # The registry definition wins; each edge adapts to its own storage
        # class.
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(
            a_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        self.assertEqual(
            b_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        # generalize_ref is the inverse of adapt_dataset_ref.
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_output_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where an output edge
        has a different but compatible storage class from the dataset type
        already registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        self.assertEqual(
            a_o.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_output(self) -> None:
        """Test successful resolution of a dataset type where an input edge has
        a different but compatible storage class from the output edge, and
        the dataset type is not registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        graph.resolve(MockRegistry(self.dimensions, {}))
        # With no registry definition, the output edge's storage class wins.
        self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
        self.assertEqual(
            a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type,
        )
        self.assertEqual(
            b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_input(self) -> None:
        """Test successful resolution of a component dataset type due to
        another input referencing the parent dataset type.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        # The graph node is always the parent dataset type.
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_output(self) -> None:
        """Test successful resolution of a component dataset type due to
        an output connection referencing the parent dataset type.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_registry(self) -> None:
        """Test successful resolution of a component dataset type due to
        the parent dataset type already being registered.
        """
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.make_empty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
if __name__ == "__main__":
    # Initialize the LSST test utilities before handing off to unittest's
    # command-line runner.
    lsst.utils.tests.init()
    unittest.main()