Coverage for tests/test_pipeline_graph.py: 15%
511 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-13 09:52 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Tests of things related to the GraphBuilder class."""
30import copy
31import io
32import logging
33import unittest
34from typing import Any
36import lsst.pipe.base.automatic_connection_constants as acc
37import lsst.utils.tests
38from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory
39from lsst.daf.butler.registry import MissingDatasetTypeError
40from lsst.pipe.base.pipeline_graph import (
41 ConnectionTypeConsistencyError,
42 DuplicateOutputError,
43 Edge,
44 EdgesChangedError,
45 IncompatibleDatasetTypeError,
46 NodeKey,
47 NodeType,
48 PipelineGraph,
49 PipelineGraphError,
50 TaskImportMode,
51 UnresolvedGraphError,
52)
53from lsst.pipe.base.tests.mocks import (
54 DynamicConnectionConfig,
55 DynamicTestPipelineTask,
56 DynamicTestPipelineTaskConfig,
57 get_mock_name,
58)
60_LOG = logging.getLogger(__name__)
class MockRegistry:
    """A minimal stand-in for `lsst.daf.butler.Registry` for unit tests.

    It supports only the dimension universe attribute and dataset-type
    lookup, which is all `PipelineGraph.resolve` needs.
    """

    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
        self.dimensions = dimensions
        self._dataset_types = dataset_types

    def getDatasetType(self, name: str) -> DatasetType:
        # Translate a plain dict miss into the butler's registry exception,
        # matching what a real Registry raises for unknown dataset types.
        dataset_type = self._dataset_types.get(name)
        if dataset_type is None:
            raise MissingDatasetTypeError(name)
        return dataset_type
79class PipelineGraphTestCase(unittest.TestCase):
80 """Tests for the `PipelineGraph` class.
82 Tests for `PipelineGraph.resolve` are mostly in
83 `PipelineGraphResolveTestCase` later in this file.
84 """
86 def setUp(self) -> None:
87 # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
88 # 'input', 'intermediate', and 'output'. There are no dimensions on
89 # any of those. We add tasks in reverse order to better test sorting.
90 # There is one labeled task subset, 'only_b', with just 'b' in it.
91 # We copy the configs so the originals (the instance attributes) can
92 # be modified and reused after the ones passed in to the graph are
93 # frozen.
94 self.description = "A pipeline for PipelineGraph unit tests."
95 self.graph = PipelineGraph()
96 self.graph.description = self.description
97 self.b_config = DynamicTestPipelineTaskConfig()
98 self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
99 self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
100 self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
101 self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
102 self.a_config = DynamicTestPipelineTaskConfig()
103 self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
104 self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
105 self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
106 self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
107 self.graph.add_task_subset("only_b", ["b"])
108 self.subset_description = "A subset with only task B in it."
109 self.graph.task_subsets["only_b"].description = self.subset_description
110 self.dimensions = DimensionUniverse()
111 self.maxDiff = None
113 def test_unresolved_accessors(self) -> None:
114 """Test attribute accessors, iteration, and simple methods on a graph
115 that has not had `PipelineGraph.resolve` called on it.
116 """
117 self.check_base_accessors(self.graph)
118 self.assertEqual(
119 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)"
120 )
122 def test_sorting(self) -> None:
123 """Test sort methods on PipelineGraph."""
124 self.assertFalse(self.graph.has_been_sorted)
125 self.assertFalse(self.graph.is_sorted)
126 self.graph.sort()
127 self.check_sorted(self.graph)
129 def test_unresolved_xgraph_export(self) -> None:
130 """Test exporting an unresolved PipelineGraph to networkx in various
131 ways.
132 """
133 self.check_make_xgraph(self.graph, resolved=False)
134 self.check_make_bipartite_xgraph(self.graph, resolved=False)
135 self.check_make_task_xgraph(self.graph, resolved=False)
136 self.check_make_dataset_type_xgraph(self.graph, resolved=False)
138 def test_unresolved_stream_io(self) -> None:
139 """Test round-tripping an unresolved PipelineGraph through in-memory
140 serialization.
141 """
142 stream = io.BytesIO()
143 self.graph._write_stream(stream)
144 stream.seek(0)
145 roundtripped = PipelineGraph._read_stream(stream)
146 self.check_make_xgraph(roundtripped, resolved=False)
148 def test_unresolved_file_io(self) -> None:
149 """Test round-tripping an unresolved PipelineGraph through file
150 serialization.
151 """
152 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
153 self.graph._write_uri(filename)
154 roundtripped = PipelineGraph._read_uri(filename)
155 self.check_make_xgraph(roundtripped, resolved=False)
157 def test_unresolved_deferred_import_io(self) -> None:
158 """Test round-tripping an unresolved PipelineGraph through
159 serialization, without immediately importing tasks on read.
160 """
161 stream = io.BytesIO()
162 self.graph._write_stream(stream)
163 stream.seek(0)
164 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
165 self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
166 # Check that we can still resolve the graph without importing tasks.
167 roundtripped.resolve(MockRegistry(self.dimensions, {}))
168 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
169 roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
170 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
172 def test_resolved_accessors(self) -> None:
173 """Test attribute accessors, iteration, and simple methods on a graph
174 that has had `PipelineGraph.resolve` called on it.
176 This includes the accessors available on unresolved graphs as well as
177 new ones, and we expect the resolved graph to be sorted as well.
178 """
179 self.graph.resolve(MockRegistry(self.dimensions, {}))
180 self.check_base_accessors(self.graph)
181 self.check_sorted(self.graph)
182 self.assertEqual(
183 repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})"
184 )
185 self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty)
186 self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})")
187 self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1"))
188 self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
189 self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
190 self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")
192 def test_resolved_xgraph_export(self) -> None:
193 """Test exporting a resolved PipelineGraph to networkx in various
194 ways.
195 """
196 self.graph.resolve(MockRegistry(self.dimensions, {}))
197 self.check_make_xgraph(self.graph, resolved=True)
198 self.check_make_bipartite_xgraph(self.graph, resolved=True)
199 self.check_make_task_xgraph(self.graph, resolved=True)
200 self.check_make_dataset_type_xgraph(self.graph, resolved=True)
202 def test_resolved_stream_io(self) -> None:
203 """Test round-tripping a resolved PipelineGraph through in-memory
204 serialization.
205 """
206 self.graph.resolve(MockRegistry(self.dimensions, {}))
207 stream = io.BytesIO()
208 self.graph._write_stream(stream)
209 stream.seek(0)
210 roundtripped = PipelineGraph._read_stream(stream)
211 self.check_make_xgraph(roundtripped, resolved=True)
213 def test_resolved_file_io(self) -> None:
214 """Test round-tripping a resolved PipelineGraph through file
215 serialization.
216 """
217 self.graph.resolve(MockRegistry(self.dimensions, {}))
218 with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
219 self.graph._write_uri(filename)
220 roundtripped = PipelineGraph._read_uri(filename)
221 self.check_make_xgraph(roundtripped, resolved=True)
223 def test_resolved_deferred_import_io(self) -> None:
224 """Test round-tripping a resolved PipelineGraph through serialization,
225 without immediately importing tasks on read.
226 """
227 self.graph.resolve(MockRegistry(self.dimensions, {}))
228 stream = io.BytesIO()
229 self.graph._write_stream(stream)
230 stream.seek(0)
231 roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
232 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
233 roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
234 self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
236 def test_unresolved_copies(self) -> None:
237 """Test making copies of an unresolved PipelineGraph."""
238 copy1 = self.graph.copy()
239 self.assertIsNot(copy1, self.graph)
240 self.check_make_xgraph(copy1, resolved=False)
241 copy2 = copy.copy(self.graph)
242 self.assertIsNot(copy2, self.graph)
243 self.check_make_xgraph(copy2, resolved=False)
244 copy3 = copy.deepcopy(self.graph)
245 self.assertIsNot(copy3, self.graph)
246 self.check_make_xgraph(copy3, resolved=False)
248 def test_resolved_copies(self) -> None:
249 """Test making copies of a resolved PipelineGraph."""
250 self.graph.resolve(MockRegistry(self.dimensions, {}))
251 copy1 = self.graph.copy()
252 self.assertIsNot(copy1, self.graph)
253 self.check_make_xgraph(copy1, resolved=True)
254 copy2 = copy.copy(self.graph)
255 self.assertIsNot(copy2, self.graph)
256 self.check_make_xgraph(copy2, resolved=True)
257 copy3 = copy.deepcopy(self.graph)
258 self.assertIsNot(copy3, self.graph)
259 self.check_make_xgraph(copy3, resolved=True)
261 def check_base_accessors(self, graph: PipelineGraph) -> None:
262 """Run parameterized tests that check attribute access, iteration, and
263 simple methods.
265 The given graph must be unchanged from the one defined in `setUp`,
266 other than sorting.
267 """
268 self.assertEqual(graph.description, self.description)
269 self.assertEqual(graph.tasks.keys(), {"a", "b"})
270 self.assertEqual(
271 graph.dataset_types.keys(),
272 {
273 "schema",
274 "input_1",
275 "intermediate_1",
276 "output_1",
277 "a_config",
278 "a_log",
279 "a_metadata",
280 "b_config",
281 "b_log",
282 "b_metadata",
283 },
284 )
285 self.assertEqual(graph.task_subsets.keys(), {"only_b"})
286 self.assertEqual(
287 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)},
288 {
289 (
290 NodeKey(NodeType.DATASET_TYPE, "input_1"),
291 NodeKey(NodeType.TASK, "a"),
292 "input_1 -> a (input1)",
293 ),
294 (
295 NodeKey(NodeType.TASK, "a"),
296 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
297 "a -> intermediate_1 (output1)",
298 ),
299 (
300 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
301 NodeKey(NodeType.TASK, "b"),
302 "intermediate_1 -> b (input1)",
303 ),
304 (
305 NodeKey(NodeType.TASK, "b"),
306 NodeKey(NodeType.DATASET_TYPE, "output_1"),
307 "b -> output_1 (output1)",
308 ),
309 (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"),
310 (
311 NodeKey(NodeType.TASK, "a"),
312 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
313 "a -> a_metadata (_metadata)",
314 ),
315 (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"),
316 (
317 NodeKey(NodeType.TASK, "b"),
318 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
319 "b -> b_metadata (_metadata)",
320 ),
321 },
322 )
323 self.assertEqual(
324 {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)},
325 {
326 (
327 NodeKey(NodeType.TASK_INIT, "a"),
328 NodeKey(NodeType.DATASET_TYPE, "schema"),
329 "a -> schema (out_schema)",
330 ),
331 (
332 NodeKey(NodeType.DATASET_TYPE, "schema"),
333 NodeKey(NodeType.TASK_INIT, "b"),
334 "schema -> b (in_schema)",
335 ),
336 (
337 NodeKey(NodeType.TASK_INIT, "a"),
338 NodeKey(NodeType.DATASET_TYPE, "a_config"),
339 "a -> a_config (_config)",
340 ),
341 (
342 NodeKey(NodeType.TASK_INIT, "b"),
343 NodeKey(NodeType.DATASET_TYPE, "b_config"),
344 "b -> b_config (_config)",
345 ),
346 },
347 )
348 self.assertEqual(
349 {(node_type, name) for node_type, name, _ in graph.iter_nodes()},
350 {
351 NodeKey(NodeType.TASK, "a"),
352 NodeKey(NodeType.TASK, "b"),
353 NodeKey(NodeType.TASK_INIT, "a"),
354 NodeKey(NodeType.TASK_INIT, "b"),
355 NodeKey(NodeType.DATASET_TYPE, "schema"),
356 NodeKey(NodeType.DATASET_TYPE, "input_1"),
357 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
358 NodeKey(NodeType.DATASET_TYPE, "output_1"),
359 NodeKey(NodeType.DATASET_TYPE, "a_config"),
360 NodeKey(NodeType.DATASET_TYPE, "a_log"),
361 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
362 NodeKey(NodeType.DATASET_TYPE, "b_config"),
363 NodeKey(NodeType.DATASET_TYPE, "b_log"),
364 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
365 },
366 )
367 self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"})
368 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"})
369 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"})
370 self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set())
371 self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"})
372 self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"})
373 self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set())
375 self.assertIsNone(graph.producing_edge_of("input_1"))
376 self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a")
377 self.assertEqual(graph.producing_edge_of("output_1").task_label, "b")
378 self.assertIsNone(graph.producer_of("input_1"))
379 self.assertEqual(graph.producer_of("intermediate_1").label, "a")
380 self.assertEqual(graph.producer_of("output_1").label, "b")
382 self.assertEqual(graph.inputs_of("a").keys(), {"input_1"})
383 self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"})
384 self.assertEqual(graph.inputs_of("a", init=True).keys(), set())
385 self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"})
386 self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"})
387 self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"})
388 self.assertEqual(
389 graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"}
390 )
391 self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"})
392 self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"})
393 self.assertEqual(
394 graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"}
395 )
396 self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"})
397 self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set())
399 self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
400 self.assertEqual(
401 repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}"
402 )
404 def check_sorted(self, graph: PipelineGraph) -> None:
405 """Run a battery of tests on a PipelineGraph that must be
406 deterministically sorted.
408 The given graph must be unchanged from the one defined in `setUp`,
409 other than sorting.
410 """
411 self.assertTrue(graph.has_been_sorted)
412 self.assertTrue(graph.is_sorted)
413 self.assertEqual(
414 [(node_type, name) for node_type, name, _ in graph.iter_nodes()],
415 [
416 # We only advertise that the order is topological and
417 # deterministic, so this test is slightly over-specified; there
418 # are other orders that are consistent with our guarantees.
419 NodeKey(NodeType.DATASET_TYPE, "input_1"),
420 NodeKey(NodeType.TASK_INIT, "a"),
421 NodeKey(NodeType.DATASET_TYPE, "a_config"),
422 NodeKey(NodeType.DATASET_TYPE, "schema"),
423 NodeKey(NodeType.TASK_INIT, "b"),
424 NodeKey(NodeType.DATASET_TYPE, "b_config"),
425 NodeKey(NodeType.TASK, "a"),
426 NodeKey(NodeType.DATASET_TYPE, "a_log"),
427 NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
428 NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
429 NodeKey(NodeType.TASK, "b"),
430 NodeKey(NodeType.DATASET_TYPE, "b_log"),
431 NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
432 NodeKey(NodeType.DATASET_TYPE, "output_1"),
433 ],
434 )
435 # Most users should only care that the tasks and dataset types are
436 # topologically sorted.
437 self.assertEqual(list(graph.tasks), ["a", "b"])
438 self.assertEqual(
439 list(graph.dataset_types),
440 [
441 "input_1",
442 "a_config",
443 "schema",
444 "b_config",
445 "a_log",
446 "a_metadata",
447 "intermediate_1",
448 "b_log",
449 "b_metadata",
450 "output_1",
451 ],
452 )
453 # __str__ and __repr__ of course work on unsorted mapping views, too,
454 # but the order of elements is then nondeterministic and hard to test.
455 self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})")
456 self.assertEqual(
457 repr(self.graph.dataset_types),
458 (
459 "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, "
460 "intermediate_1, b_log, b_metadata, output_1})"
461 ),
462 )
    def check_make_xgraph(
        self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
    ) -> None:
        """Check that the given graph exports as expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``) or round-tripped
        through serialization without tasks being imported (if
        ``imported_and_configured=False``).
        """
        xgraph = graph.make_xgraph()
        # The full export carries both runtime and init edges, plus the
        # synthetic edge that links each task-init node to its task node.
        expected_edges = (
            {edge.key for edge in graph.iter_edges()}
            | {edge.key for edge in graph.iter_edges(init=True)}
            | {
                (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
                (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
            }
        )
        test_edges = set(xgraph.edges)
        self.assertEqual(test_edges, expected_edges)
        # Expected node attributes are built by the get_expected_* helpers so
        # the same expectations can be shared with the other export checks.
        expected_nodes = {
            NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
                "a", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
                "b", resolved, imported_and_configured=imported_and_configured
            ),
            NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
            NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
            NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                "schema", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                "input_1", resolved, is_initial_query_constraint=True
            ),
            NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                "intermediate_1", resolved, is_initial_query_constraint=False
            ),
            NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                "output_1", resolved, is_initial_query_constraint=False
            ),
        }
        test_nodes = dict(xgraph.nodes.items())
        # Compare key sets first for a readable failure, then compare each
        # node's attributes individually, passing the key as the assertion
        # message so a mismatch identifies the offending node.
        self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
        for key, expected_node in expected_nodes.items():
            test_node = test_nodes[key]
            self.assertEqual(expected_node, test_node, key)
    def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's init-only or runtime subset exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime subgraph: tasks plus their runtime dataset types only.
        run_xgraph = graph.make_bipartite_xgraph()
        self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init subgraph: task-init nodes plus their init dataset types only.
        init_xgraph = graph.make_bipartite_xgraph(
            init=True,
        )
        self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
    def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's task-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: dataset type nodes are contracted away, leaving
        # a single a -> b edge from the intermediate_1 connection.
        run_xgraph = graph.make_task_xgraph()
        self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
                NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
            },
        )
        # Init projection: the schema connection yields a single a -> b edge
        # between the task-init nodes.
        init_xgraph = graph.make_task_xgraph(
            init=True,
        )
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
                NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
            },
        )
    def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
        """Check that the given graph's dataset-type-only projection exports as
        expected to networkx.

        The given graph must be unchanged from the one defined in `setUp`,
        other than being resolved (if ``resolved=True``).
        """
        # Runtime projection: task nodes are contracted away, so each task's
        # inputs gain edges to all of that task's outputs (including the
        # automatic log and metadata outputs).
        run_xgraph = graph.make_dataset_type_xgraph()
        self.assertEqual(
            set(run_xgraph.edges),
            {
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
                (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "output_1"),
                ),
                (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
                (
                    NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
                    NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
                ),
            },
        )
        self.assertEqual(
            dict(run_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
                NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
                    "input_1", resolved, is_initial_query_constraint=True
                ),
                NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
                    "intermediate_1", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
                    "output_1", resolved, is_initial_query_constraint=False
                ),
            },
        )
        # Init projection: only the schema -> b_config dependency survives the
        # contraction of the task-init nodes.
        init_xgraph = graph.make_dataset_type_xgraph(init=True)
        self.assertEqual(
            set(init_xgraph.edges),
            {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
        )
        self.assertEqual(
            dict(init_xgraph.nodes.items()),
            {
                NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
                    "schema", resolved, is_initial_query_constraint=False
                ),
                NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
                NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
            },
        )
659 def get_expected_task_node(
660 self, label: str, resolved: bool, imported_and_configured: bool = True
661 ) -> dict[str, Any]:
662 """Construct a networkx-export task node for comparison."""
663 result = self.get_expected_task_init_node(
664 label, resolved, imported_and_configured=imported_and_configured
665 )
666 if resolved:
667 result["dimensions"] = self.dimensions.empty
668 result["raw_dimensions"] = frozenset()
669 return result
671 def get_expected_task_init_node(
672 self, label: str, resolved: bool, imported_and_configured: bool = True
673 ) -> dict[str, Any]:
674 """Construct a networkx-export task init for comparison."""
675 result = {
676 "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask",
677 "bipartite": 1,
678 }
679 if imported_and_configured:
680 result["task_class"] = DynamicTestPipelineTask
681 result["config"] = getattr(self, f"{label}_config")
682 return result
684 def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]:
685 """Construct a networkx-export init-output config dataset type node for
686 comparison.
687 """
688 if not resolved:
689 return {"bipartite": 0}
690 else:
691 return {
692 "dataset_type": DatasetType(
693 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
694 self.dimensions.empty,
695 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
696 ),
697 "is_initial_query_constraint": False,
698 "is_prerequisite": False,
699 "dimensions": self.dimensions.empty,
700 "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
701 "bipartite": 0,
702 }
704 def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]:
705 """Construct a networkx-export output log dataset type node for
706 comparison.
707 """
708 if not resolved:
709 return {"bipartite": 0}
710 else:
711 return {
712 "dataset_type": DatasetType(
713 acc.LOG_OUTPUT_TEMPLATE.format(label=label),
714 self.dimensions.empty,
715 acc.LOG_OUTPUT_STORAGE_CLASS,
716 ),
717 "is_initial_query_constraint": False,
718 "is_prerequisite": False,
719 "dimensions": self.dimensions.empty,
720 "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS,
721 "bipartite": 0,
722 }
724 def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]:
725 """Construct a networkx-export output metadata dataset type node for
726 comparison.
727 """
728 if not resolved:
729 return {"bipartite": 0}
730 else:
731 return {
732 "dataset_type": DatasetType(
733 acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
734 self.dimensions.empty,
735 acc.METADATA_OUTPUT_STORAGE_CLASS,
736 ),
737 "is_initial_query_constraint": False,
738 "is_prerequisite": False,
739 "dimensions": self.dimensions.empty,
740 "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS,
741 "bipartite": 0,
742 }
744 def get_expected_connection_node(
745 self, name: str, resolved: bool, *, is_initial_query_constraint: bool
746 ) -> dict[str, Any]:
747 """Construct a networkx-export dataset type node for comparison."""
748 if not resolved:
749 return {"bipartite": 0}
750 else:
751 return {
752 "dataset_type": DatasetType(
753 name,
754 self.dimensions.empty,
755 get_mock_name("StructuredDataDict"),
756 ),
757 "is_initial_query_constraint": is_initial_query_constraint,
758 "is_prerequisite": False,
759 "dimensions": self.dimensions.empty,
760 "storage_class_name": get_mock_name("StructuredDataDict"),
761 "bipartite": 0,
762 }
    def test_construct_with_data_coordinate(self) -> None:
        """Test constructing a graph with a DataCoordinate.

        Since this creates a graph with a DimensionUniverse, all tasks added
        to it should have resolved dimensions, but not (yet) resolved dataset
        types.  We use that to test a few other operations in that state.
        """
        data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
        graph = PipelineGraph(data_id=data_id)
        self.assertEqual(graph.universe, self.dimensions)
        self.assertEqual(graph.data_id, data_id)
        graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
        # Dimensions resolve immediately because the universe is known.
        self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
        # Still can't group by dimensions, because the dataset types aren't
        # resolved.
        with self.assertRaises(UnresolvedGraphError):
            graph.group_by_dimensions()
        # Transferring a node from this graph to ``self.graph`` should
        # unresolve the dimensions.
        self.graph.add_task_nodes([graph.tasks["b1"]])
        self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
        self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
        # Do the opposite transfer, which should resolve dimensions.
        graph.add_task_nodes([self.graph.tasks["a"]])
        self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
        self.assertTrue(graph.tasks["a"].has_resolved_dimensions)
    def test_group_by_dimensions(self) -> None:
        """Test PipelineGraph.group_by_dimensions."""
        # Grouping requires resolved dataset types.
        with self.assertRaises(UnresolvedGraphError):
            self.graph.group_by_dimensions()
        # Reconfigure task "a" to visit-level dimensions, with a visit-level
        # output and an htm7-level calibration prerequisite...
        self.a_config.dimensions = ["visit"]
        self.a_config.outputs["output1"].dimensions = ["visit"]
        self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
            dataset_type_name="prereq_1",
            multiple=True,
            dimensions=["htm7"],
            is_calibration=True,
        )
        # ...and task "b" to htm7-level dimensions, consuming a's visit-level
        # output and producing an htm7-level output.
        self.b_config.dimensions = ["htm7"]
        self.b_config.inputs["input1"].dimensions = ["visit"]
        self.b_config.inputs["input1"].multiple = True
        self.b_config.outputs["output1"].dimensions = ["htm7"]
        self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        visit_dims = self.dimensions.extract(["visit"])
        htm7_dims = self.dimensions.extract(["htm7"])
        # Expected result maps each dimension group to a (task nodes,
        # dataset type nodes) pair; logs/metadata/config follow their task's
        # dimensions, while init-type datasets (schema, configs) are empty.
        expected = {
            self.dimensions.empty: (
                {},
                {
                    "schema": self.graph.dataset_types["schema"],
                    "input_1": self.graph.dataset_types["input_1"],
                    "a_config": self.graph.dataset_types["a_config"],
                    "b_config": self.graph.dataset_types["b_config"],
                },
            ),
            visit_dims: (
                {"a": self.graph.tasks["a"]},
                {
                    "a_log": self.graph.dataset_types["a_log"],
                    "a_metadata": self.graph.dataset_types["a_metadata"],
                    "intermediate_1": self.graph.dataset_types["intermediate_1"],
                },
            ),
            htm7_dims: (
                {"b": self.graph.tasks["b"]},
                {
                    "b_log": self.graph.dataset_types["b_log"],
                    "b_metadata": self.graph.dataset_types["b_metadata"],
                    "output_1": self.graph.dataset_types["output_1"],
                },
            ),
        }
        self.assertEqual(self.graph.group_by_dimensions(), expected)
        # Prerequisite dataset types are included only when requested.
        expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
        self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)
842 def test_add_and_remove(self) -> None:
843 """Tests for adding and removing tasks and task subsets from a
844 PipelineGraph.
845 """
846 # Can't remove a task while it's still in a subset.
847 with self.assertRaises(PipelineGraphError):
848 self.graph.remove_tasks(["b"], drop_from_subsets=False)
849 # ...unless you remove the subset.
850 self.graph.remove_task_subset("only_b")
851 self.assertFalse(self.graph.task_subsets)
852 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False)
853 self.assertFalse(referencing_subsets)
854 self.assertEqual(self.graph.tasks.keys(), {"a"})
855 # Add that task back in.
856 self.graph.add_task_nodes([b])
857 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
858 # Add the subset back in.
859 self.graph.add_task_subset("only_b", {"b"})
860 self.assertEqual(self.graph.task_subsets.keys(), {"only_b"})
861 # Resolve the graph's dataset types and task dimensions.
862 self.graph.resolve(MockRegistry(self.dimensions, {}))
863 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
864 self.assertTrue(self.graph.dataset_types.is_resolved("output_1"))
865 self.assertTrue(self.graph.dataset_types.is_resolved("schema"))
866 self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1"))
867 # Remove the task while removing it from the subset automatically. This
868 # should also unresolve (only) the referenced dataset types and drop
869 # any datasets no longer attached to any task.
870 self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
871 ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
872 self.assertEqual(referencing_subsets, {"only_b"})
873 self.assertEqual(self.graph.tasks.keys(), {"a"})
874 self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
875 self.assertNotIn("output1", self.graph.dataset_types)
876 self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
877 self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
    def test_reconfigure(self) -> None:
        """Tests for PipelineGraph.reconfigure."""
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        # Changing an output's storage class changes that edge.
        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
        with self.assertRaises(ValueError):
            # Can't check and assume together.
            self.graph.reconfigure_tasks(
                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
            )
        # Check that graph is unchanged after error.
        self.check_base_accessors(self.graph)
        with self.assertRaises(EdgesChangedError):
            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
        self.check_base_accessors(self.graph)
        # Make a change that does affect edges; this will unresolve most
        # dataset types (only "input_1", untouched by task "b", stays
        # resolved).
        self.graph.reconfigure_tasks(b=self.b_config)
        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
        # Resolving again will pick up the new storage class
        self.graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(
            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
        )
def _have_example_storage_classes() -> bool:
    """Check whether some storage classes work as expected.

    Given that these have registered converters, it shouldn't actually be
    necessary to be able to import those types in order to determine that
    they're convertible, but the storage class machinery is implemented such
    that types that can't be imported can't be converted, and while that's
    inconvenient here it's totally fine in non-testing scenarios where you only
    care about a storage class if you can actually use it.
    """
    getter = StorageClassFactory().getStorageClass
    # Conversions that the example-dependent tests below rely on, checked in
    # order with short-circuiting (same evaluation order as a chained `and`).
    required_conversions = [
        ("ArrowTable", "ArrowAstropy"),
        ("ArrowAstropy", "ArrowTable"),
        ("ArrowTable", "DataFrame"),
        ("DataFrame", "ArrowTable"),
    ]
    return all(getter(src).can_convert(getter(dst)) for src, dst in required_conversions)
class PipelineGraphResolveTestCase(unittest.TestCase):
    """More extensive tests for PipelineGraph.resolve and its private helper
    methods.

    These are in a separate TestCase because they utilize a different `setUp`
    from the rest of the `PipelineGraph` tests.
    """

    def setUp(self) -> None:
        # Each test customizes these two task configs before building the
        # graph via `make_graph`.
        self.a_config = DynamicTestPipelineTaskConfig()
        self.b_config = DynamicTestPipelineTaskConfig()
        self.dimensions = DimensionUniverse()
        self.maxDiff = None

    def make_graph(self) -> PipelineGraph:
        """Construct an unresolved two-task graph ("a" and "b") from the
        current configs.
        """
        graph = PipelineGraph()
        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
        return graph

    def test_prerequisite_inconsistency(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite and another does not.

        This test will hopefully someday go away (along with
        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
        algorithm becomes more flexible.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_inconsistency_reversed(self) -> None:
        """Same as `test_prerequisite_inconsistency`, with the order the edges
        are added to the graph reversed.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_prerequisite_output(self) -> None:
        """Test that we raise an exception when one edge defines a dataset type
        as a prerequisite but another defines it as an output.
        """
        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(ConnectionTypeConsistencyError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_missing(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the dataset type is not registered.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix"}
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_skypix_inconsistent(self) -> None:
        """Test that we raise an exception when one edge uses the "skypix"
        dimension as a placeholder but the rest of the dimensions are
        inconsistent with the registered dataset type.
        """
        self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", dimensions={"skypix", "visit"}
        )
        graph = self.make_graph()
        # Registered dimensions are missing "visit" entirely.
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )
        # Registered dimensions have "visit" but also an extra "skymap".
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {
                        "d": DatasetType(
                            "d",
                            dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]),
                            storageClass="StructuredDataDict",
                        )
                    },
                )
            )

    def test_duplicate_outputs(self) -> None:
        """Test that we raise an exception when a dataset type node would have
        two write edges.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
        graph = self.make_graph()
        with self.assertRaises(DuplicateOutputError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_component_of_unregistered_parent(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is not registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_undefined_component(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered, but its storage class does not have that
        component.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_bad_component_storage_class(self) -> None:
        """Test that we raise an exception when a component dataset type's
        parent is registered and has that component, but the storage class
        requested by the connection is incompatible with the component's
        actual storage class.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="StructuredDataDict"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
                )
            )

    def test_input_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_output_storage_class_incompatible_with_registry(self) -> None:
        """Test that we raise an exception when an output connection's storage
        class is incompatible with the registry definition.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(
                MockRegistry(
                    self.dimensions,
                    {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
                )
            )

    def test_input_storage_class_incompatible_with_output(self) -> None:
        """Test that we raise an exception when an input connection's storage
        class is incompatible with the storage class of the output connection.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(IncompatibleDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    def test_ambiguous_storage_class(self) -> None:
        """Test that we raise an exception when two input connections define
        the same dataset with different storage classes (even compatible ones)
        and there is no output connection or registry definition to take
        precedence.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataDict"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="StructuredDataList"
        )
        graph = self.make_graph()
        with self.assertRaises(MissingDatasetTypeError):
            graph.resolve(MockRegistry(self.dimensions, {}))

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where input edges have
        different but compatible storage classes and the dataset type is
        already registered.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        # The registry definition wins; each edge adapts to its own
        # storage class on top of it.
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(
            a_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        self.assertEqual(
            b_i.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        # generalize_ref round-trips an adapted ref back to the graph's
        # definition.
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_output_compatible_with_registry(self) -> None:
        """Test successful resolution of a dataset type where an output edge
        has a different but compatible storage class from the dataset type
        already registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        graph = self.make_graph()
        dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
        graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        self.assertEqual(
            a_o.adapt_dataset_type(dataset_type),
            dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_inputs_compatible_with_output(self) -> None:
        """Test successful resolution of a dataset type where an input edge has
        a different but compatible storage class from the output edge, and
        the dataset type is not registered.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowAstropy"
        )
        graph = self.make_graph()
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        graph.resolve(MockRegistry(self.dimensions, {}))
        # With no registry definition, the output edge's storage class wins.
        self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
        self.assertEqual(
            a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type,
        )
        self.assertEqual(
            b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
            graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_input(self) -> None:
        """Test successful resolution of a component dataset type due to
        another input referencing the parent dataset type.
        """
        self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        # The graph node is always the parent dataset type; the component
        # edge adapts refs to the component.
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_i = graph.tasks["a"].inputs["i"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_i.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_output(self) -> None:
        """Test successful resolution of a component dataset type due to
        an output connection referencing the parent dataset type.
        """
        self.a_config.outputs["o"] = DynamicConnectionConfig(
            dataset_type_name="d", storage_class="ArrowTable"
        )
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        a_o = graph.tasks["a"].outputs["o"]
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        a_ref = a_o.adapt_dataset_ref(ref)
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(a_ref, ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)

    @unittest.skipUnless(
        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
    def test_component_resolved_by_registry(self) -> None:
        """Test successful resolution of a component dataset type due to
        the parent dataset type already being registered.
        """
        self.b_config.inputs["i"] = DynamicConnectionConfig(
            dataset_type_name="d.schema", storage_class="ArrowSchema"
        )
        graph = self.make_graph()
        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
        b_i = graph.tasks["b"].inputs["i"]
        self.assertEqual(b_i.dataset_type_name, "d.schema")
        self.assertEqual(
            b_i.adapt_dataset_type(parent_dataset_type),
            parent_dataset_type.makeComponentDatasetType("schema"),
        )
        data_id = DataCoordinate.makeEmpty(self.dimensions)
        ref = DatasetRef(parent_dataset_type, data_id, run="r")
        b_ref = b_i.adapt_dataset_ref(ref)
        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
if __name__ == "__main__":
    # Standard LSST test-module entry point: initialize the test utilities,
    # then hand off to unittest's CLI runner.
    lsst.utils.tests.init()
    unittest.main()